[FP8] rebase main (#5963)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use one cross-entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up the construction of DimSpec objects. The difference_dict is the same for every DimSpec object, so a single shared copy is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
flybird11111 2024-08-06 16:29:37 +08:00 committed by GitHub
parent 53cb9606bd
commit 0c10afd372
208 changed files with 10962 additions and 2892 deletions

View File

@ -1 +1,3 @@
2.1.0-12.1.0
2.2.2-12.1.0
2.3.0-12.1.0

View File

@ -55,41 +55,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@ -49,42 +49,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@ -43,47 +43,28 @@ jobs:
steps:
- name: Install dependencies
run: |
apt update && apt install -y cmake
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@ -52,6 +52,7 @@ jobs:
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
./tests/test_data_preparation.sh
./tests/test_train.sh
env:
@ -61,3 +62,4 @@ jobs:
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data

View File

@ -10,7 +10,7 @@ import math
import os
from multiprocessing import cpu_count
from colossal_llama.dataset.conversation import LLaMA2_Conv
from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AddedToken, AutoTokenizer
@ -75,6 +75,8 @@ def main():
# Prepare to the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
default_conversation = LLaMA3_Conv
# Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)

View File

@ -128,6 +128,12 @@ def main() -> None:
parser.add_argument("--zero", type=int, default=1)
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
parser.add_argument(
"--skip_save_each_epoch",
action="store_true",
default=False,
help="skip saving the model checkpoint after each epoch is completed.",
)
args = parser.parse_args()
with open(args.config_file, "w") as f:
@ -370,11 +376,17 @@ def main() -> None:
)
total_loss.fill_(0.0)
pbar.update()
# Save modeling.
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
step + 1
) == len(dataloader):
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if save_model_condition:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:

View File

@ -146,6 +146,9 @@ docs/.build
examples/wandb/
examples/logs/
examples/output/
examples/training_scripts/logs
examples/training_scripts/wandb
examples/training_scripts/output
examples/awesome-chatgpt-prompts/
temp/

View File

@ -23,6 +23,10 @@
- [Open QA](#open-qa)
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@ -135,17 +139,15 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
"from": "assistant",
"content": "Are you looking for practical joke ideas?"
},
...
]
},
...
]
```
@ -171,23 +173,20 @@ Below shows the preference dataset format used in training the reward model.
"from": "human",
"content": "Introduce butterflies species in Oregon."
}
]
],
"chosen": [
{
"from": "assistant",
"content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
},
...
],
"rejected": [
{
"from": "assistant",
"content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
},
...
]
},
...
]
```
@ -216,7 +215,6 @@ PPO uses two kind of training data--- the prompt data and the sft data (optional
"from": "human",
"content": "what are some pranks with a pen i can do?"
}
...
]
},
]
@ -262,9 +260,8 @@ experience buffer size
= train_batch_size * accumulation_steps * num_tp_group
```
## Alternative Option For RLHF: Direct Preference Optimization
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper (available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO.
## Alternative Option For RLHF: Direct Preference Optimization (DPO)
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. As detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO. Read this [README](./examples/README.md) for more information.
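For reference, a common formulation of the DPO objective is sketched below (a reference formula, not necessarily the exact loss implemented in this repo); $\pi_{\mathrm{ref}}$ is the frozen reference model, $\beta$ the implicit KL-penalty strength, and $y_w$/$y_l$ the chosen/rejected responses:
```math
\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
```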
### DPO Training Stage1 - Supervised Instructs Tuning
@ -277,6 +274,15 @@ For DPO training, you only need the preference dataset. Please follow the instru
#### Step 2: Training
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More detais can be found in [example guideline](./examples/README.md).
## Alternative Option For RLHF: Simple Preference Optimization (SimPO)
Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but abandons the reference model, which makes training more efficient. It also adds a reward shaping term called the target reward margin to enhance training stability, and uses length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
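As a sketch of the idea (a reference formulation from the SimPO paper, not necessarily the exact code path), the reference-free, length-normalized objective with target reward margin $\gamma$ is:
```math
\mathcal{L}_{\mathrm{SimPO}}(\theta) = -\,\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\!\left(\frac{\beta}{|y_w|} \log \pi_\theta(y_w \mid x) - \frac{\beta}{|y_l|} \log \pi_\theta(y_l \mid x) - \gamma\right)\right]
```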
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference-model-free alignment method that uses a mixture of the SFT loss and a reinforcement learning loss based on an odds-ratio implicit reward, which makes training more efficient and stable. Read this [README](./examples/README.md) for more information.
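A sketch of the ORPO objective from the paper (a reference formulation; the `--lam` flag in the benchmark scripts presumably plays the role of the weight $\lambda$ below):
```math
\mathcal{L}_{\mathrm{ORPO}} = \mathbb{E}\left[\mathcal{L}_{\mathrm{SFT}} + \lambda \cdot \mathcal{L}_{\mathrm{OR}}\right], \qquad \mathcal{L}_{\mathrm{OR}} = -\log \sigma\!\left(\log \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)}\right), \qquad \mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}
```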
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
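As a sketch (a reference formulation from the KTO paper, not necessarily the exact implementation here), each sample contributes a value depending on whether it is labeled desirable or undesirable, with $z_{\mathrm{ref}}$ an estimate of the policy-to-reference KL and $\lambda_D$, $\lambda_U$ the class weights:
```math
r_\theta(x, y) = \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}, \qquad v(x, y) = \begin{cases} \lambda_D\, \sigma\big(\beta (r_\theta(x, y) - z_{\mathrm{ref}})\big) & y \text{ desirable} \\ \lambda_U\, \sigma\big(\beta (z_{\mathrm{ref}} - r_\theta(x, y))\big) & y \text{ undesirable} \end{cases}
```
and training maximizes the expected value $v$ over the dataset.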
### Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
@ -441,20 +447,6 @@ If you only have a single 24G GPU. Generally, using lora and "zero2-cpu" will be
If you have multiple GPUs that each have very limited VRAM, say 8GB, you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs that you have.
</details>
## The Plan
- [x] implement PPO fine-tuning
- [x] implement training reward model
- [x] support LoRA
- [x] support inference
- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
- [x] implement PPO-ptx fine-tuning
- [x] support flash-attention
- [x] implement DPO fine-tuning
- [ ] integrate with Ray
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)
### Real-time progress
You will find our progress on the GitHub [project board](https://github.com/orgs/hpcaitech/projects/17/views/1).
@ -522,7 +514,7 @@ Coati is developed by ColossalAI Team:
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with the updated acceleration framework; added support for DPO, SimPO and ORPO.
The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
@ -572,6 +564,36 @@ We also appreciate the valuable suggestions provided by [Jian Hu](https://github
journal = {GitHub repository},
howpublished = {\url{https://github.com/XueFuzhao/InstructionWild}},
}
@misc{meng2024simposimplepreferenceoptimization,
title={SimPO: Simple Preference Optimization with a Reference-Free Reward},
author={Yu Meng and Mengzhou Xia and Danqi Chen},
year={2024},
eprint={2405.14734},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2405.14734},
}
@misc{rafailov2023directpreferenceoptimizationlanguage,
title={Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
author={Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
year={2023},
eprint={2305.18290},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2305.18290},
}
@misc{hong2024orpomonolithicpreferenceoptimization,
title={ORPO: Monolithic Preference Optimization without Reference Model},
author={Jiwoo Hong and Noah Lee and James Thorne},
year={2024},
eprint={2403.07691},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2403.07691},
}
```
## Licenses

View File

@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
DATASET_SIZE=320
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 4 \
--lr 1e-6 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
--use_flash_attn

View File

@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
DATASET_SIZE=80
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto
colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 2 \
--lr 1e-5 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
--use_flash_attn

View File

@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="orpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/orpo" # Path to benchmark data
DATASET_SIZE=160
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 4 \
--lr 8e-6 \
--lam 0.5 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
--use_flash_attn

View File

@ -0,0 +1,50 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="sft"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/sft" # Path to benchmark data
DATASET_SIZE=640
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type sft
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
colossalai run --nproc_per_node 1 --master_port 31312 ../examples/training_scripts/train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
--accumulation_steps 1 \
--lr 5e-5 \
--lora_rank 32 \
--max_len 2048 \
--grad_checkpoint \
--use_flash_attn

View File

@ -0,0 +1,55 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="simpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/simpo" # Path to benchmark data
DATASET_SIZE=640
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--loss_type "simpo_loss" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
--lr 1e-6 \
--beta 0.1 \
--gamma 0.6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--disable_reference_model \
--length_normalization \
--grad_checkpoint \
--use_flash_attn

View File

@ -0,0 +1,30 @@
from typing import Callable
from torch.utils.data import Dataset
class DummyLLMDataset(Dataset):
def __init__(self, keys, seq_len, size=500, gen_fn={}):
self.keys = keys
self.gen_fn = gen_fn
self.seq_len = seq_len
self.data = self._generate_data()
self.size = size
def _generate_data(self):
data = {}
for key in self.keys:
if key in self.gen_fn:
data[key] = self.gen_fn[key]
else:
data[key] = [1] * self.seq_len
return data
def __len__(self):
return self.size
def __getitem__(self, idx):
return {
key: self.data[key] if not isinstance(self.data[key], Callable) else self.data[key](idx)
for key in self.keys
}
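A minimal usage sketch of `DummyLLMDataset` (illustration only, not part of the diff); the key names mirror the KTO branch of the preparation script below:
```python
# Build a small KTO-style dummy dataset: fixed-length prompts, a 512-token
# completion generated per sample, and a label that alternates between 0 and 1.
ds = DummyLLMDataset(
    keys=["prompt", "completion", "label"],
    seq_len=1536,
    size=8,
    gen_fn={"completion": lambda idx: [1] * 512, "label": lambda idx: idx % 2},
)
sample = ds[3]
print(len(sample["prompt"]), len(sample["completion"]), sample["label"])  # -> 1536 512 1
```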

View File

@ -0,0 +1,105 @@
import argparse
import json
import os
import time
from multiprocessing import cpu_count
from datasets import load_dataset
from dummy_dataset import DummyLLMDataset
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir",
type=str,
required=True,
default=None,
help="The output dir",
)
parser.add_argument(
"--dataset_size",
type=int,
required=True,
default=None,
help="The size of data",
)
parser.add_argument(
"--max_length",
type=int,
required=True,
default=None,
help="The max length of data",
)
parser.add_argument(
"--data_type",
type=str,
required=True,
default=None,
help="The type of data, choose one from ['sft', 'prompt', 'preference', 'kto']",
)
args = parser.parse_args()
if args.data_type == "sft":
dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_length, args.dataset_size)
elif args.data_type == "prompt":
# PPO prompt dataset is prepared separately; nothing to generate here
pass
elif args.data_type == "preference":
dataset = DummyLLMDataset(
["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
args.max_length,
args.dataset_size,
)
elif args.data_type == "kto":
dataset = DummyLLMDataset(
["prompt", "completion", "label"],
args.max_length - 512,
args.dataset_size,
gen_fn={
"completion": lambda x: [1] * 512,
"label": lambda x: x % 2,
},
)
else:
raise ValueError(f"Unknown data type {args.data_type}")
# Save each jsonl spliced dataset.
output_index = "0"
output_name = f"part-{output_index}"
os.makedirs(args.data_dir, exist_ok=True)
output_jsonl_path = os.path.join(args.data_dir, "json")
output_arrow_path = os.path.join(args.data_dir, "arrow")
output_cache_path = os.path.join(args.data_dir, "cache")
os.makedirs(output_jsonl_path, exist_ok=True)
os.makedirs(output_arrow_path, exist_ok=True)
output_jsonl_file_path = os.path.join(output_jsonl_path, output_name + ".jsonl")
st = time.time()
with open(file=output_jsonl_file_path, mode="w", encoding="utf-8") as fp_writer:
count = 0
for i in range(len(dataset)):
data_point = dataset[i]
if count % 500 == 0:
logger.info(f"processing {count} spliced data points for {fp_writer.name}")
count += 1
fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
logger.info(
f"Current file {fp_writer.name}; "
f"Data size: {len(dataset)}; "
f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
)
# Save each arrow spliced dataset
output_arrow_file_path = os.path.join(output_arrow_path, output_name)
logger.info(f"Start to save {output_arrow_file_path}")
dataset = load_dataset(
path="json",
data_files=[output_jsonl_file_path],
cache_dir=os.path.join(output_cache_path, "tokenized"),
keep_in_memory=False,
num_proc=cpu_count(),
split="train",
)
dataset.save_to_disk(dataset_path=output_arrow_file_path, num_proc=min(len(dataset), cpu_count()))

View File

@ -1,24 +1,26 @@
from .conversation import Conversation, setup_conversation_template
from .loader import (
DataCollatorForKTODataset,
DataCollatorForPreferenceDataset,
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
)
from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
__all__ = [
"tokenize_prompt_dataset",
"tokenize_prompt",
"DataCollatorForPromptDataset",
"is_rank_0",
"DataCollatorForPreferenceDataset",
"DataCollatorForSupervisedDataset",
"DataCollatorForKTODataset",
"StatefulDistributedSampler",
"load_tokenized_dataset",
"supervised_tokenize_pretrain",
"supervised_tokenize_sft",
"tokenize_sft",
"tokenize_rlhf",
"tokenize_kto",
"setup_conversation_template",
"Conversation",
]

View File

@ -18,6 +18,7 @@ class Conversation:
chat_template: str
stop_ids: List[int]
end_of_assistant: str
roles = ["user", "assistant"]
@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@ -85,7 +86,7 @@ class Conversation:
Raises:
AssertionError: If the role is not 'user' or 'assistant'.
"""
assert role in ["user", "assistant"]
assert role in self.roles
self.messages.append({"role": role, "content": message})
def copy(self):

View File

@ -28,6 +28,8 @@ def load_tokenized_dataset(
Each instance of dataset is a dictionary with
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
"""
if not dataset_paths:
return None
mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"})
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
@ -233,6 +235,91 @@ class DataCollatorForPreferenceDataset(object):
)
@dataclass
class DataCollatorForKTODataset(object):
"""
Collate instances for the KTO dataset.
Each input instance is a tokenized dictionary with fields
`prompt`(List[int]), `completion`(List[int]) and `label`(bool).
Each output instance is a tokenized dictionary with fields
`input_ids`(List[int]), `attention_mask`(List[int]), `loss_mask`(List[int]), `label`(bool),
`kl_input_ids`(List[int]), `kl_attention_mask`(List[int]) and `kl_loss_mask`(List[int]).
"""
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary contains the following fields:
`prompt`(List[int]), `completion`(List[int]) and `label`(bool, if the sample is desirable or not).
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
"""
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
f"but now `{self.tokenizer.pad_token_id}`"
)
# prepare the preference data
prompt = [torch.LongTensor(instance["prompt"]) for instance in instances]
prompt_zeros = [torch.zeros_like(t) for t in prompt]
completion = [torch.LongTensor(instance["completion"]) for instance in instances]
completion_ones = [torch.ones_like(t) for t in completion]
label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances]
input_ids = [torch.cat([prompt[i], completion[i]], dim=-1) for i in range(len(instances))]
loss_mask = [torch.cat([prompt_zeros[i], completion_ones[i]], dim=-1) for i in range(len(instances))]
# right padding
input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
loss_mask = torch.nn.utils.rnn.pad_sequence(
sequences=loss_mask, batch_first=True, padding_value=0
) # (bsz, max_len)
to_pad = self.max_length - input_ids.size(1)
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
loss_mask = F.pad(loss_mask, (0, to_pad), value=0)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
# prepare the KL data (mismatched completions y')
kl_completion = completion[::-1] # y'
kl_completion_ones = [torch.ones_like(t) for t in kl_completion]
kl_input_ids = [torch.cat([prompt[i], kl_completion[i]], dim=-1) for i in range(len(instances))]
kl_loss_mask = [torch.cat([prompt_zeros[i], kl_completion_ones[i]], dim=-1) for i in range(len(instances))]
# right padding
kl_input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=kl_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
kl_loss_mask = torch.nn.utils.rnn.pad_sequence(
sequences=kl_loss_mask, batch_first=True, padding_value=0
) # (bsz, max_len)
to_pad = self.max_length - kl_input_ids.size(1)
kl_input_ids = F.pad(kl_input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
kl_loss_mask = F.pad(kl_loss_mask, (0, to_pad), value=0)
kl_attention_mask = kl_input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
data_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"label": torch.stack(label),
"kl_input_ids": kl_input_ids,
"kl_attention_mask": kl_attention_mask,
"kl_loss_mask": kl_loss_mask,
}
return data_dict
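A minimal usage sketch of the collator above (illustration only, with assumed toy inputs; the tokenizer name is a placeholder and any tokenizer with a valid `pad_token_id` works):
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model name
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # the collator asserts a valid pad_token_id

collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=64)
# Each sample is already tokenized into `prompt`, `completion` and `label`.
samples = [
    {"prompt": [1, 42, 7], "completion": [99, 2], "label": True},
    {"prompt": [1, 13], "completion": [55, 66, 2], "label": False},
]
batch = collator(samples)
print(batch["input_ids"].shape, batch["kl_input_ids"].shape)  # (2, 64) (2, 64)
```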
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,

View File

@ -23,11 +23,10 @@ IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
def supervised_tokenize_sft(
def tokenize_sft(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@ -39,51 +38,37 @@ def supervised_tokenize_sft(
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 0:
# Force to end with assistant response
template.messages = template.messages[0:-1]
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
turns = [i for i in range(1, len(messages) // 2 + 1)]
lo, hi = 0, len(turns)
while lo < hi:
mid = (lo + hi) // 2
if max_length - 1 < len(
tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
):
hi = mid
else:
lo = mid + 1
target_turn_index = lo
# The tokenized length for first turn already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
# tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines
prompt = template.get_prompt()
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages, prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
if tokenized is None:
return dict(
input_ids=None,
labels=None,
@ -93,44 +78,18 @@ def supervised_tokenize_sft(
seq_category=None,
)
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [ignore_index]
labels[start:end] = tokenized[start:end]
# truncate the sequence at the last token that requires loss calculation
to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
if labels[i] == ignore_index:
to_truncate_len += 1
else:
break
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]
if tokenizer.bos_token_id is not None:
# Force to add bos token at the beginning of the tokenized sequence if the input ids don't start with bos
if tokenized[0] != tokenizer.bos_token_id:
# Some chat templates already include bos token
tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] + labels
labels = [-100] + labels
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [tokenizer.eos_token_id]
else:
labels[-1] = tokenizer.eos_token_id
# For some model without bos/eos may raise the following errors
# log decoded inputs and labels for debugging
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
@ -167,11 +126,10 @@ def supervised_tokenize_sft(
)
def tokenize_prompt_dataset(
def tokenize_prompt(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@ -179,48 +137,39 @@ def tokenize_prompt_dataset(
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
target_turn = len(template.messages)
if target_turn % 2 != 1:
if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
target_turn = target_turn - 1
template.messages = template.messages[:-1]
# Prepare data
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[:target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
# Skip overlength data
if max_length - 1 < len(tokenized):
if len(tokenized) > max_length:
return dict(
input_ids=None,
inputs_decode=None,
@ -231,47 +180,32 @@ def tokenize_prompt_dataset(
# `inputs_decode` can be used to check whether the tokenization is correct.
return dict(
input_ids=tokenized,
inputs_decode=tokenizer.decode(tokenized),
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
def apply_rlhf_data_format(
template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
):
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
loss_mask = [0] * len(tokenized)
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
if mask_token is None:
mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
# no truncation applied
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
loss_mask = [0] * len(tokenized)
label_decode = []
for start, end in zip(starts[-1:], ends[-1:]):
# only the last round (chosen/rejected) counts
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
loss_mask[start:end] = [1] * len(loss_mask[start:end])
label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
# only the last round (chosen/rejected) is used to calculate loss
for i in range(starts[-1], ends[-1]):
loss_mask[i] = 1
label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
loss_mask = [0] + loss_mask
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
else:
loss_mask[-1] = 1
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
@ -279,39 +213,29 @@ def tokenize_rlhf(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following:
{"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
{"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
context = data_point["context"]
template = deepcopy(conversation_template)
template.clear()
for mess in context:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
# Concate adjacent message from the same role
template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
else:
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{context}"
)
template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 1:
warnings.warn(
"Please make sure leading context starts and ends with a line from human\nLeading context: "
"Please make sure leading context starts and ends with a line from user\nLeading context: "
+ str(template.messages)
)
return dict(
@ -322,31 +246,27 @@ def tokenize_rlhf(
rejected_loss_mask=None,
rejected_label_decode=None,
)
round_of_context = int((len(template.messages) - 1) / 2)
assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
chosen = deepcopy(template)
rejected = deepcopy(template)
chosen_continuation = data_point["chosen"]
rejected_continuation = data_point["rejected"]
for round in range(len(chosen_continuation)):
if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{chosen_continuation}"
)
chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
for round in range(len(data_point["chosen"])):
from_str = data_point["chosen"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
chosen.append_message(from_str, data_point["chosen"][round]["content"])
for round in range(len(data_point["rejected"])):
from_str = data_point["rejected"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
rejected.append_message(from_str, data_point["rejected"][round]["content"])
for round in range(len(rejected_continuation)):
if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{rejected_continuation}"
)
rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
(
chosen_input_ids,
@ -356,48 +276,22 @@ def tokenize_rlhf(
rejected_loss_mask,
rejected_label_decode,
) = (None, None, None, None, None, None)
if (
len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
<= max_length - 1
and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
<= max_length - 1
):
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"],
)
rejected_data_packed = apply_rlhf_data_format(
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"],
rejected_data_packed["label_decode"],
)
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"],
)
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask):
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
chosen_label_decode=None,
rejected_input_ids=None,
rejected_loss_mask=None,
rejected_label_decode=None,
)
rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"],
rejected_data_packed["label_decode"],
)
return {
"chosen_input_ids": chosen_input_ids,
"chosen_loss_mask": chosen_loss_mask,
"chosen_label_decode": chosen_label_decode,
"rejected_input_ids": rejected_input_ids,
"rejected_loss_mask": rejected_loss_mask,
"rejected_label_decode": rejected_label_decode,
}
else:
if len(chosen_input_ids) > max_length or len(rejected_input_ids) > max_length:
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
@ -406,3 +300,81 @@ def tokenize_rlhf(
rejected_loss_mask=None,
rejected_label_decode=None,
)
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
chosen_label_decode=None,
rejected_input_ids=None,
rejected_loss_mask=None,
rejected_label_decode=None,
)
return {
"chosen_input_ids": chosen_input_ids,
"chosen_loss_mask": chosen_loss_mask,
"chosen_label_decode": chosen_label_decode,
"rejected_input_ids": rejected_input_ids,
"rejected_loss_mask": rejected_loss_mask,
"rejected_label_decode": rejected_label_decode,
}
def tokenize_kto(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
Tokenize a dataset for KTO training
The raw input data is a conversation with the following format
{
"prompt": [{"from": "user", "content": "xxx"}...],
"completion": {"from": "assistant", "content": "xxx"},
"label": true/false
}
It returns three fields:
the context, which contains the query and the assistant start,
the completion, which contains only the assistant's answer,
and a binary label, which indicates whether the sample is preferred.
"""
prompt = data_point["prompt"]
completion = data_point["completion"]
template = deepcopy(conversation_template)
template.clear()
if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
raise ValueError("conversation should end with assistant")
for mess in prompt:
if mess.get("from", None) == "user":
template.append_message("user", mess["content"])
elif mess.get("from", None) == "assistant":
template.append_message("assistant", mess["content"])
else:
raise ValueError(f"Unsupported role {mess.get('from', None)}")
generation_prompt = template.get_prompt(len(prompt), add_generation_prompt=True)
template.append_message("assistant", completion["content"])
full_prompt = template.get_prompt(len(prompt) + 1, add_generation_prompt=False)
tokenized_full_prompt = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]
if len(tokenized_full_prompt) + 1 > max_length:
return dict(prompt=None, completion=None, label=None, input_id_decode=None, completion_decode=None)
tokenized_generation_prompt = tokenizer(generation_prompt, add_special_tokens=False)["input_ids"]
tokenized_completion = tokenized_full_prompt[len(tokenized_generation_prompt) :]
tokenized_completion = deepcopy(tokenized_completion)
if tokenizer.bos_token_id is not None and tokenized_generation_prompt[0] != tokenizer.bos_token_id:
tokenized_generation_prompt = [tokenizer.bos_token_id] + tokenized_generation_prompt
decoded_full_prompt = tokenizer.decode(tokenized_full_prompt, skip_special_tokens=False)
decoded_completion = tokenizer.decode(tokenized_completion, skip_special_tokens=False)
return {
"prompt": tokenized_generation_prompt,
"completion": tokenized_completion,
"label": data_point["label"],
"input_id_decode": decoded_full_prompt,
"completion_decode": decoded_completion,
}
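Illustrative usage sketch (not part of the patch): the ToyTemplate below is a hypothetical stand-in for the Coati conversation template, mimicking only the clear/append_message/get_prompt interface that tokenize_kto relies on; the tokenizer choice and prompt format are arbitrary.

from dataclasses import dataclass, field
from typing import Dict, List

from transformers import AutoTokenizer


@dataclass
class ToyTemplate:
    # Hypothetical stand-in for the real Conversation template, for demonstration only.
    messages: List[Dict[str, str]] = field(default_factory=list)

    def clear(self):
        self.messages = []

    def append_message(self, role: str, content: str):
        self.messages.append({"role": role, "content": content})

    def get_prompt(self, length: int, add_generation_prompt: bool = False) -> str:
        text = "".join(f"<|{m['role']}|>\n{m['content']}</s>\n" for m in self.messages[:length])
        return text + ("<|assistant|>\n" if add_generation_prompt else "")


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for this sketch
data_point = {
    "prompt": [{"from": "user", "content": "Summarize Hamlet in one sentence."}],
    "completion": {"from": "assistant", "content": "A Danish prince avenges his father's murder."},
    "label": True,  # True marks a desirable completion, False an undesirable one
}
sample = tokenize_kto(data_point, tokenizer, conversation_template=ToyTemplate(), max_length=4096)
print(sample["label"], len(sample["prompt"]), len(sample["completion"]))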

View File

@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
return -1
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
def tokenize_and_concatenate(
tokenizer: PreTrainedTokenizer,
text: List[str],
require_loss: List[bool],
max_length: int,
discard_non_loss_tokens_at_tail: bool = True,
):
"""
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
text (List[str]): The list of texts to tokenize.
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
max_length: used to truncate the input ids
discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
if the first round has already exceeded max length
- if the user query alone already exceeds max length, discard the sample
- if only the first assistant response exceeds max length, truncate the response to fit the max length
else keep the first several complete rounds of the conversation until max length is reached
Returns:
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
@ -106,10 +119,18 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
if max_length and loss_starts[0] >= max_length:
return None, None, None
if discard_non_loss_tokens_at_tail:
input_ids = input_ids[: loss_ends[-1]]
if max_length:
input_ids = input_ids[:max_length]
loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends
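Illustrative usage sketch of the new signature (not part of the patch); the tokenizer and chunk texts below are arbitrary:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = [
    "Human: How do I sort a list in Python?\n",          # prompt chunk, no loss
    "Assistant: Use the built-in sorted() function.\n",  # response chunk, contributes to the loss
]
require_loss = [False, True]

input_ids, loss_starts, loss_ends = tokenize_and_concatenate(tokenizer, text, require_loss, max_length=512)
# Returns (None, None, None) when even the first loss span would start beyond max_length.
print(len(input_ids), loss_starts, loss_ends)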
@ -125,6 +146,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
content_length = (
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
)
# if the tokenized content starts with a leading space, we want to keep it in the loss calculation
# e.g., Assistant: I am saying...
# if the tokenized content doesn't start with a leading space, we only need to keep the content in loss calculation
# e.g.,
# Assistant: # '\n' as line breaker
# I am saying...
if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + content_length])

View File

@ -1,8 +1,8 @@
from .base import BaseModel
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
from .lora import convert_to_lora_module
from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout
@ -14,9 +14,11 @@ __all__ = [
"ValueLoss",
"LogSigLoss",
"LogExpLoss",
"LoraConfig",
"lora_manager",
"convert_to_lora_module",
"DpoLoss",
"generate",
"KTOLoss" "generate",
"generate_streaming",
"disable_dropout",
"update_model_kwargs_fn",

View File

@ -42,7 +42,6 @@ class BaseModel(nn.Module):
out = self.model(dummy_input)
self.last_hidden_state_size = out.last_hidden_state.shape[-1]
self.model = self.model.cpu()
# print("self.last_hidden_state_size: ",self.last_hidden_state_size)
def resize_token_embeddings(self, *args, **kwargs):
"""

View File

@ -5,10 +5,11 @@ LORA utils
import dataclasses
import math
import warnings
from typing import Optional
from typing import List, Optional, Union
import loralib as lora
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
@ -18,148 +19,349 @@ logger = get_dist_logger()
@dataclasses.dataclass
class LoRAManager:
merge_weights: bool = False
class LoraManager:
able_to_merge: bool = True
LORA_MANAGER = LoRAManager()
lora_manager = LoraManager()
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
@dataclasses.dataclass
class LoraConfig:
r: int = 0
lora_alpha: int = 32
linear_lora_dropout: float = 0.1
embedding_lora_dropout: float = 0.0
lora_train_bias: str = "none"
lora_initialization_method: str = "kaiming_uniform"
target_modules: List = None
@classmethod
def from_file(cls, config_file: str):
import json
with open(config_file, "r") as f:
config = json.load(f)
return cls(**config)
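Illustrative sketch of building a LoraConfig, either directly or from a JSON file (the field values and file name below are examples, not defaults shipped with the patch):

import json

config = {
    "r": 8,
    "lora_alpha": 32,
    "linear_lora_dropout": 0.1,
    "embedding_lora_dropout": 0.0,
    "lora_train_bias": "none",
    "lora_initialization_method": "kaiming_uniform",
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
}
with open("lora_config.json", "w") as f:
    json.dump(config, f)

lora_config = LoraConfig.from_file("lora_config.json")
# Equivalent direct construction:
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])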
class LoraBase(lora.LoRALayer, nn.Module):
def __init__(
self,
weight: nn.Parameter,
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
fan_in_fan_out: bool = False,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
lora_initialization_method: str = "kaiming_uniform",
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
self.weight = weight
self.bias = bias
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
self.r = r
self.lora_alpha = lora_alpha
self.lora_dropout = nn.Dropout(lora_dropout)
self.merged = False
self.lora_initialization_method = lora_initialization_method
self.weight = None
self.bias = None
self.lora_A = None
self.lora_B = None
def reset_parameters(self):
if hasattr(self, "lora_A"):
# Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != (
self.out_features,
self.in_features,
):
# Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
elif self.lora_initialization_method == "PiSSA":
# PiSSA method in this paper: https://arxiv.org/abs/2404.02948
# Assume the SVD of the original weights is W = USV^T
# Initialize the frozen weight to the residual W - U[:,:r] S[:r,:r] V^T[:,:r], i.e. the less significant part of W
# Only A and B are trainable; they are initialized to S[:r,:r]^0.5 V^T[:,:r] and U[:,:r] S[:r,:r]^0.5 respectively (up to the LoRA scaling factor)
# self.scaling = 1.
# SVD
U, S, Vh = torch.svd_lowrank(
self.weight.to(torch.float32).data, self.r, niter=4
) # U: [out_features, in_features], S: [in_features], V: [in_features, in_features]
# weight_backup = self.weight.clone()
# Initialize A, B
S = S / self.scaling
self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous()
self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous()
# Initialize weight
# To reduce floating point error, we use residual instead of directly using U[:, :self.r] @ S[:self.r] @ Vh[:self.r, :]
self.weight.data = (
((self.weight - self.scaling * self.lora_B @ self.lora_A)).contiguous().to(self.weight.dtype)
)
self.lora_A.requires_grad = True
self.lora_B.requires_grad = True
else:
raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}")
def train(self, mode: bool = True):
"""
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
"""
def T(w):
return w.T if self.fan_in_fan_out else w
self.training = mode
if LORA_MANAGER.merge_weights:
if mode and self.merged:
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
raise NotImplementedError("LoRA unmerge is not tested.")
# Make sure that the weights are not merged
if self.r > 0:
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
# FIXME(csric): temporary fix
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
self.reset_parameters()
else:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
elif not mode and not self.merged:
warnings.warn("Invoke module.eval() would merge LoRA weights.")
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True
if mode and self.merged:
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
raise NotImplementedError("LoRA unmerge is not tested.")
elif not mode and not self.merged and lora_manager.able_to_merge:
warnings.warn("Invoke module.eval() would merge LoRA weights.")
# Merge the weights and mark it
if self.r > 0:
self.weight.data += self.lora_B @ self.lora_A * self.scaling
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True
return self
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
class LoraLinear(LoraBase):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
weight: nn.Parameter,
bias: Union[nn.Parameter, bool],
r: int = 0,
lora_alpha: int = 32,
lora_dropout: float = 0.0,
lora_initialization_method: str = "kaiming_uniform",
):
super().__init__(
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
)
self.weight = weight
self.bias = bias
if bias is True:
self.bias = nn.Parameter(torch.zeros(weight.shape[0]))
if bias is not None:
self.bias.requires_grad = True
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
self.lora_initialization_method = lora_initialization_method
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
result = F.linear(x, self.weight, bias=self.bias)
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
return F.linear(x, self.weight, bias=self.bias)
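Illustrative sanity-check sketch of the PiSSA initialization (not part of the patch): after wrapping, the frozen residual weight plus scaling * B @ A reconstructs the original weight up to floating-point error. Note that LoraLinear reuses, and modifies in place, the weight tensor it is given.

import torch
import torch.nn as nn

base = nn.Linear(64, 64, bias=True)
original_weight = base.weight.detach().clone()

lora_layer = LoraLinear(base.weight, base.bias, r=8, lora_initialization_method="PiSSA")
reconstructed = lora_layer.weight + lora_layer.scaling * lora_layer.lora_B @ lora_layer.lora_A
print(torch.allclose(reconstructed, original_weight, atol=1e-4))  # expected: True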
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
class LoraEmbedding(LoraBase):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
weight: nn.Parameter,
r: int = 0,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
num_embeddings: int = None,
embedding_dim: int = None,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
lora_initialization_method: str = "kaiming_uniform",
):
super().__init__(
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
)
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.weight = weight
in_features, out_features = num_embeddings, embedding_dim
self.in_features = in_features
self.out_features = out_features
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
self.lora_initialization_method = lora_initialization_method
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# reset parameters
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)
def _embed(self, x: torch.Tensor, weight) -> torch.Tensor:
return F.embedding(
x,
weight,
padding_idx=self.padding_idx,
max_norm=self.max_norm,
norm_type=self.norm_type,
scale_grad_by_freq=self.scale_grad_by_freq,
sparse=self.sparse,
)
def forward(self, x: torch.Tensor):
base_embedding = self._embed(x, self.weight)
# base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing
if self.r > 0 and not self.merged:
lora_A_embedding = self._embed(x, self.lora_A.t())
embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling
return embedding
else:
return base_embedding
def train(self, mode: bool = True):
"""
This function runs when model.train() is invoked. It is used to prepare the embedding layer for training
"""
self.training = mode
if mode and self.merged:
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
raise NotImplementedError("LoRA unmerge is not tested.")
elif not mode and not self.merged and lora_manager.able_to_merge:
warnings.warn("Invoke module.eval() would merge LoRA weights.")
# Merge the weights and mark it
if self.r > 0:
self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True
return self
def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear:
"""
Wraps a linear layer with LoRA functionality.
Args:
linear (nn.Linear): The linear layer to be wrapped.
lora_rank (int): The rank of the LoRA decomposition.
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
Returns:
LoraLinear: The wrapped linear layer with LoRA functionality.
"""
assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
lora_config.r <= linear.in_features
), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})"
bias = None
if lora_config.lora_train_bias in ["all", "lora"]:
bias = linear.bias
if bias is None:
bias = True
lora_linear = LoraLinear(
linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method
)
return lora_linear
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None:
"""
Recursively converts the given module and its children to LoRA (Low-Rank Adaptation) form.
Args:
module (nn.Module): The module to convert to LoRA form.
lora_rank (int): The rank of the LoRA approximation.
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
parent_name (str): The name of the parent module.
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
Returns:
None
"""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
if lora_config.target_modules is None or any(
[name in target_module for target_module in lora_config.target_modules]
):
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(f"Converting {parent_name}.{name} to LoRA")
setattr(module, name, _lora_linear_wrapper(child, lora_config))
elif isinstance(child, nn.Embedding):
if lora_config.target_modules is None or any(
[name in target_module for target_module in lora_config.target_modules]
):
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(f"Converting {parent_name}.{name} to LoRA")
setattr(
module,
name,
LoraEmbedding(
child.weight,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.embedding_lora_dropout,
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
lora_initialization_method=lora_config.lora_initialization_method,
),
)
else:
_convert_to_lora_recursively(child, lora_rank)
_convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
module (nn.Module): The module to convert.
lora_rank (int): LoRA rank.
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
Returns:
nn.Module: The converted module.
"""
if lora_rank <= 0:
if lora_config.r <= 0:
return module
_convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
# make all parameters non-trainable; if lora_train_bias is "all", set bias parameters back to trainable
total_parameter_size = 0
for name, p in module.named_parameters():
p.requires_grad = False
if "bias" in name and lora_config.lora_train_bias == "all":
p.requires_grad = True
total_parameter_size += p.numel()
_convert_to_lora_recursively(module, "", lora_config)
trainable_parameter_size = 0
for name, p in module.named_parameters():
if p.requires_grad == True:
trainable_parameter_size += p.numel()
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(
f"Trainable parameter size: {trainable_parameter_size/1024/1024:.2f}M\nOriginal trainable parameter size: {total_parameter_size/1024/1024:.2f}M\nPercentage: {trainable_parameter_size/total_parameter_size*100:.2f}%"
)
return module
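Illustrative usage sketch of the new entry point (not part of the patch); the toy model layout is arbitrary:

import torch.nn as nn

toy_model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.Linear(16, 100))
lora_config = LoraConfig(r=4, lora_train_bias="none")
toy_model = convert_to_lora_module(toy_model, lora_config)
# Inspect which parameters remain trainable after the conversion (the LoRA A/B matrices and wrapper-created biases).
print([name for name, p in toy_model.named_parameters() if p.requires_grad])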

View File

@ -5,6 +5,7 @@ loss functions
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from .utils import masked_mean
@ -89,11 +90,22 @@ class DpoLoss(nn.Module):
"""
Dpo loss
Details: https://arxiv.org/pdf/2305.18290.pdf
SimPO loss:
Details: https://arxiv.org/pdf/2405.14734.pdf
"""
def __init__(self, beta: float = 0.1):
def __init__(self, beta: float = 0.1, gamma: float = 0.0):
"""
Args:
beta: The temperature parameter in the DPO paper.
gamma: The margin parameter in the SimPO paper.
length_normalization: Whether to normalize the loss by the length of chosen and rejected responses.
Refer to the length normalization in the SimPO paper
"""
super().__init__()
self.beta = beta
self.gamma = gamma
def forward(
self,
@ -104,7 +116,7 @@ class DpoLoss(nn.Module):
chosen_mask: torch.Tensor,
reject_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities.
"""Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities.
# adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328
@ -113,6 +125,8 @@ class DpoLoss(nn.Module):
logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
chosen_mask: Loss mask of the chosen responses (1 for tokens that contribute to the loss). Shape: (batch_size, sequence_length)
reject_mask: Loss mask of the rejected responses (1 for tokens that contribute to the loss). Shape: (batch_size, sequence_length)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
@ -127,13 +141,12 @@ class DpoLoss(nn.Module):
if len(logprob_ref_chosen.shape) == 2:
ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
else:
ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
ref_logratios = logprob_ref_chosen - logprob_ref_reject
else:
# If no reference model is provided
ref_logratios = 0.0
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
logits = pi_logratios - ref_logratios
logits = pi_logratios - ref_logratios - self.gamma / self.beta
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
# Calculate rewards for logging
@ -168,3 +181,93 @@ class LogExpLoss(nn.Module):
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
return loss
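Illustrative sketch of the updated DpoLoss on dummy per-token log-probabilities (not part of the patch); setting gamma > 0 adds the SimPO reward margin, and all tensor values below are made up:

import torch

loss_fn = DpoLoss(beta=0.1, gamma=0.5)
logp_chosen = torch.tensor([[-0.9, -1.1, -0.8]])  # policy log-probs, shape (batch, seq)
logp_reject = torch.tensor([[-1.5, -1.7, -1.6]])
ref_chosen = torch.tensor([[-1.0, -1.2, -0.9]])   # reference-model log-probs
ref_reject = torch.tensor([[-1.4, -1.6, -1.5]])
mask = torch.ones_like(logp_chosen)

losses, chosen_rewards, rejected_rewards = loss_fn(logp_chosen, logp_reject, ref_chosen, ref_reject, mask, mask)
print(losses.mean().item(), (chosen_rewards > rejected_rewards).float().mean().item())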
class OddsRatioLoss(nn.Module):
"""
Odds Ratio Loss in ORPO
Details: https://arxiv.org/pdf/2403.07691
"""
def forward(
self,
chosen_logp: torch.Tensor,
reject_logp: torch.Tensor,
chosen_loss_mask: torch.Tensor,
reject_loss_mask: torch.Tensor,
) -> torch.Tensor:
chosen_logp = chosen_logp.to(dtype=torch.float32)
reject_logp = reject_logp.to(dtype=torch.float32)
chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001)
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
log_odds_ratio = chosen_odds_masked - reject_odds_masked
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
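Illustrative sketch of the odds-ratio term on dummy per-token log-probabilities (not part of the patch); a positive log_odds_ratio means the chosen response is favoured over the rejected one:

import torch

orpo_fn = OddsRatioLoss()
chosen_logp = torch.tensor([[-0.2, -0.4, -0.3]])  # per-token log-probs, shape (batch, seq)
reject_logp = torch.tensor([[-1.0, -1.2, -0.9]])
mask = torch.ones_like(chosen_logp)

odds_ratio_loss, log_odds_ratio = orpo_fn(chosen_logp, reject_logp, mask, mask)
# The ORPO trainer combines it with the NLL term as: loss = chosen_nll - odds_ratio_loss * lam
print(odds_ratio_loss.item(), log_odds_ratio.item())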
class KTOLoss(nn.Module):
def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
"""
Args:
beta: The temperature parameter in the KTO paper.
desirable_weight: The weight for the desirable responses.
undesirable_weight: The weight for the undesirable responses.
"""
super().__init__()
self.beta = beta
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight
def forward(
self,
chosen_logps: torch.Tensor,
rejected_logps: torch.Tensor,
kl_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
ref_kl_logps: torch.Tensor,
):
"""
Reference:
https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585
Compute the KTO loss for a batch of policy and reference model log probabilities.
Args:
chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
kl_logps: KL divergence of the policy model. Shape: (batch_size,)
ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
ref_kl_logps: KL divergence of the reference model. Shape: (batch_size,)
beta: The temperature parameter in the KTO paper.
desirable_weight: The weight for the desirable responses.
undesirable_weight: The weight for the undesirable responses.
Refer to the KTO paper for details about hyperparameters https://arxiv.org/pdf/2402.01306
"""
kl = (kl_logps - ref_kl_logps).mean().detach()
# all gather
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
kl = (kl / dist.get_world_size()).clamp(min=0)
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
chosen_logratios = chosen_logps - ref_chosen_logps
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
chosen_rewards = self.beta * chosen_logratios.detach()
else:
chosen_losses = torch.Tensor([]).to(kl_logps.device)
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
rejected_logratios = rejected_logps - ref_rejected_logps
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
rejected_losses = torch.Tensor([]).to(kl_logps.device)
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
return losses, chosen_rewards, rejected_rewards, kl
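Illustrative single-process sketch of KTOLoss on dummy sequence-level log-probabilities (not part of the patch). The loss all-reduces its KL estimate, so a single-rank gloo process group is initialized first; all values are made up:

import os

import torch
import torch.distributed as dist

if not dist.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)
chosen_logps = torch.tensor([-12.0, -15.0])     # policy log-probs of desirable completions
rejected_logps = torch.tensor([-20.0])          # policy log-probs of undesirable completions
kl_logps = torch.tensor([-14.0, -16.0, -18.0])  # policy log-probs on the KL (mismatched) batch
ref_chosen_logps = torch.tensor([-13.0, -15.5])
ref_rejected_logps = torch.tensor([-19.0])
ref_kl_logps = torch.tensor([-14.5, -16.0, -17.5])

loss, chosen_rewards, rejected_rewards, kl = loss_fn(
    chosen_logps, rejected_logps, kl_logps, ref_chosen_logps, ref_rejected_logps, ref_kl_logps
)
print(loss.item(), kl.item())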

View File

@ -89,7 +89,9 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
return mean
def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
def calc_masked_log_probs(
logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor, length_normalization: bool = False
) -> torch.Tensor:
"""
Calculate the masked log probabilities for a given sequence of logits.
@ -103,7 +105,11 @@ def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mas
"""
# logits are probabilities of the next token, so we shift them to the left by one
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs * mask
if not length_normalization:
return log_probs * mask
else:
return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)
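Illustrative sketch of the new length_normalization flag on dummy tensors (not part of the patch); with the flag on, each sequence's masked log-probabilities are divided by its number of loss tokens, as used by the SimPO-style objective:

import torch

logits = torch.randn(1, 6, 32)                    # (batch, seq_len, vocab_size)
sequences = torch.randint(0, 32, (1, 6))          # token ids
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])  # loss on the first three shifted positions

summed = calc_masked_log_probs(logits, sequences, mask).sum(-1)
normalized = calc_masked_log_probs(logits, sequences, mask, length_normalization=True).sum(-1)
print(summed, normalized)  # normalized equals summed / (3 + 0.01)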
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:

View File

@ -1,7 +1,18 @@
from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"]
__all__ = [
"SLTrainer",
"OLTrainer",
"RewardModelTrainer",
"SFTTrainer",
"PPOTrainer",
"DPOTrainer",
"ORPOTrainer",
"KTOTrainer",
]

View File

@ -2,6 +2,7 @@
Dpo trainer
"""
import os
from typing import Any, Optional
import torch
@ -25,7 +26,7 @@ from .utils import is_rank_0, to_device
class DPOTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Trainer for DPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
@ -53,6 +54,8 @@ class DPOTrainer(SLTrainer):
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@ -63,7 +66,7 @@ class DPOTrainer(SLTrainer):
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta)
self.actor_loss_fn = DpoLoss(beta, gamma)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@ -71,6 +74,7 @@ class DPOTrainer(SLTrainer):
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
self.length_normalization = length_normalization
def _before_fit(
self,
@ -131,18 +135,21 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
reject_loss_mask[:, -1] = False
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
if self.ref_model is not None:
self.ref_model.eval()
@ -150,14 +157,14 @@ class DPOTrainer(SLTrainer):
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
@ -219,7 +226,7 @@ class DPOTrainer(SLTrainer):
)
self.accumulative_meter.reset()
if (self.num_train_step + 1) % self.save_interval == 0:
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
@ -283,16 +290,16 @@ class DPOTrainer(SLTrainer):
actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
self.ref_model.eval()
@ -300,11 +307,15 @@ class DPOTrainer(SLTrainer):
ref_all_logits = self.ref_model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"].to(torch.float32)
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
@ -314,7 +325,7 @@ class DPOTrainer(SLTrainer):
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
@ -333,4 +344,7 @@ class DPOTrainer(SLTrainer):
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -0,0 +1,318 @@
"""
KTO trainer
"""
import os
from typing import Any, Optional
import torch
import torch.distributed
from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import SLTrainer
from .utils import is_rank_0, to_device
class KTOTrainer(SLTrainer):
"""
Trainer for KTO algorithm.
Args:
actor (Actor): the actor model in the KTO algorithm
ref_model (Critic): the reference model in the KTO algorithm
booster (Strategy): the strategy to use for training
actor_optim (Optimizer): the optimizer to use for actor model
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
max_epochs (int, defaults to 1): the max number of epochs to train
accumulation_steps (int): the number of steps to accumulate gradients
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training
save_dir (str): the directory to save checkpoints
coordinator (DistCoordinator): the coordinator to use for distributed logging
beta (float, defaults to 0.1): the beta parameter in kto loss
desirable_weight (float, defaults to 1.0): the weight for desirable reward
undesirable_weight (float, defaults to 1.0): the weight for undesirable reward
"""
def __init__(
self,
actor: Any,
ref_model: Any,
booster: Booster,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight
self.beta = beta
def _before_fit(
self,
train_preference_dataloader: DataLoader = None,
eval_preference_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
train_preference_dataloader (DataLoader): the dataloader to use for training preference data
eval_preference_dataloader (DataLoader): the dataloader to use for evaluating preference data
"""
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-kto", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "kto")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _train(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
self.model.train()
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
batch["input_ids"],
batch["attention_mask"],
batch["loss_mask"],
batch["label"],
batch["kl_input_ids"],
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
batch_size = input_ids.size()[0]
# actor logits
with torch.no_grad():
# calculate the KL term with the KL data
kl_logits = self.model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
logits = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
chosen_index = [i for i in range(batch_size) if label[i] == 1]
rejected_index = [i for i in range(batch_size) if label[i] == 0]
chosen_logprob = logprob[chosen_index]
rejected_logprob = logprob[rejected_index]
with torch.no_grad():
ref_kl_logits = self.ref_model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
ref_logits = self.ref_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
ref_chosen_logprob = ref_logprob[chosen_index]
ref_rejected_logprob = ref_logprob[rejected_index]
loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
)
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.accumulative_meter.reset()
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.accumulative_meter.reset()
step_bar = trange(
len(self.eval_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
batch["input_ids"],
batch["attention_mask"],
batch["loss_mask"],
batch["label"],
batch["kl_input_ids"],
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
batch_size = input_ids.size()[0]
# actor logits
with torch.no_grad():
# calculate the KL term with the KL data
kl_logits = self.model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
logits = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
chosen_index = [i for i in range(batch_size) if label[i] == 1]
rejected_index = [i for i in range(batch_size) if label[i] == 0]
chosen_logprob = logprob[chosen_index]
rejected_logprob = logprob[rejected_index]
with torch.no_grad():
ref_kl_logits = self.ref_model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
ref_logits = self.ref_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
ref_chosen_logprob = ref_logprob[chosen_index]
ref_rejected_logprob = ref_logprob[rejected_index]
loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
)
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
)
step_bar.update()
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -0,0 +1,314 @@
"""
ORPO trainer
"""
import os
from typing import Any, Optional
import torch
from coati.models.loss import OddsRatioLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import SLTrainer
from .utils import is_rank_0, to_device
class ORPOTrainer(SLTrainer):
"""
Trainer for ORPO algorithm.
Args:
actor (Actor): the actor model in the ORPO algorithm
booster (Strategy): the strategy to use for training
actor_optim (Optimizer): the optimizer to use for actor model
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
max_epochs (int, defaults to 1): the max number of epochs to train
lam (float, defaults to 0.1): the lambda parameter in ORPO loss
accumulation_steps (int): the number of steps to accumulate gradients
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training
save_dir (str): the directory to save checkpoints
coordinator (DistCoordinator): the coordinator to use for distributed logging
"""
def __init__(
self,
actor: Any,
booster: Booster,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.odds_ratio_loss_fn = OddsRatioLoss()
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
self,
train_preference_dataloader: DataLoader = None,
eval_preference_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
train_preference_dataloader (DataLoader): the dataloader to use for training preference data
eval_preference_dataloader (DataLoader): the dataloader to use for evaluating preference data
"""
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-orpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "orpo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _train(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
self.model.train()
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)
loss = chosen_nll - odds_ratio_loss * self.lam
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.writer.add_scalar(
"train/log_odds_ratio",
self.accumulative_meter.get("log_odds_ratio"),
self.num_train_step,
)
self.accumulative_meter.reset()
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.coordinator.print_on_master("\nStart evaluation...")
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
self.accumulative_meter.reset()
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
)
chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)
loss = chosen_nll - odds_ratio_loss * self.lam
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -237,6 +237,7 @@ class RewardModelTrainer(SLTrainer):
+ f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n"
)
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -102,6 +102,7 @@ class SFTTrainer(SLTrainer):
batch_size = batch["input_ids"].size(0)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss
self.booster.backward(loss=loss, optimizer=self.optimizer)
loss_mean = all_reduce_mean(tensor=loss)
@ -113,6 +114,7 @@ class SFTTrainer(SLTrainer):
self.optimizer.zero_grad()
self.scheduler.step()
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
@ -165,6 +167,7 @@ class SFTTrainer(SLTrainer):
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -0,0 +1,9 @@
{
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
151645,
151643
],
"end_of_assistant": "<|im_end|>"
}

View File

@ -0,0 +1,8 @@
{
"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
],
"end_of_assistant": "</s>"
}

View File

@ -9,6 +9,7 @@
- [Install Requirements](#install-requirements)
- [Get Start with ColossalRun](#get-start-with-colossalrun)
- [Training Configuration](#training-configuration)
- [Parameter Efficient Finetuning (PEFT)](#parameter-efficient-finetuning-peft)
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
- [Step 1: Data Collection](#step-1-data-collection)
- [Step 2: Preprocessing](#step-2-preprocessing)
@ -29,6 +30,9 @@
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
- [List of Supported Models](#list-of-supported-models)
- [Hardware Requirements](#hardware-requirements)
- [Inference example](#inference-example)
@ -45,9 +49,6 @@
pip install -r requirements.txt
```
## Get Start with ColossalRun
@ -81,8 +82,6 @@ Make sure the master node can access all nodes (including itself) by ssh without
This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
<details><summary><b>Gemini (Zero3)</b></summary>
@ -374,35 +373,6 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
</details>
<details><summary><b>Low Rank Adaption</b></summary>
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces VRAM consumption at the cost of some model capability. It is suitable for training LLMs with constrained resources.
To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--lora_rank 32 \ # This enables LoRA
--use_wandb
```
</details>
<details><summary><b>Other Training Arguments</b></summary>
@ -427,6 +397,60 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- use_wandb: if this flag is up, you can view logs on wandb.
</details>
### Parameter Efficient Finetuning (PEFT)
Currently, we support LoRA (low-rank adaptation) and PiSSA (principal singular values and singular vectors adaptation). Both help reduce runtime VRAM consumption as well as training time, at the cost of some overall model performance.
<details><summary><b>Low Rank Adaption and PiSSA</b></summary>
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). Details about Principal Singular Values and Singular Vectors Adaptation (PiSSA) can be found in the paper: [PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models](https://arxiv.org/abs/2404.02948). Both help reduce runtime VRAM consumption as well as training time, at the cost of some overall model performance. They are suitable for training LLMs with constrained resources.
To use LoRA/PiSSA in training, please create a config file as in the following example and set `--lora_config` to the path of that configuration file.
```json
{
"r": 128,
"embedding_lora_dropout": 0.0,
"linear_lora_dropout": 0.1,
"lora_alpha": 32,
"lora_train_bias": "all",
"lora_initialization_method": "PiSSA",
"target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
}
```
#### Lora Parameters
- r: lora rank
- embedding_lora_dropout: dropout probability for embedding layer
- linear_lora_dropout: dropout probability for linear layer
- lora_alpha: LoRA alpha, controls how much the adapter can deviate from the pretrained model.
- lora_train_bias: whether to add a trainable bias to LoRA layers; choose from "all" (all layers, including but not limited to LoRA layers, will have trainable biases), "none" (no trainable biases), "lora" (only LoRA layers will have trainable biases)
- lora_initialization_method: how to initialize LoRA weights, choose one from ["kaiming_uniform", "PiSSA"], defaults to "kaiming_uniform". Use "kaiming_uniform" for standard LoRA and "PiSSA" for PiSSA.
- target_modules: which module(s) should be converted to LoRA layers. If a module's name contains one of the keywords in target_modules and the module is a linear or embedding layer, it will be converted; otherwise, the module will be frozen. Setting this field to None will automatically convert all linear and embedding layers to their LoRA counterparts. Note that this example only works for LLaMA; for other models, you need to modify it.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--lora_config /PATH/TO/THE/LORA/CONFIG/FILE.json \ # Setting this enables LoRA
--use_wandb
```
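For reference, here is a minimal sketch of how such a config file is consumed inside the training scripts of this PR (the model path is a placeholder; `LoraConfig.from_file` and `convert_to_lora_module` are the helpers used by `train_dpo.py`/`train_kto.py`):
```python
# Sketch: load a LoRA/PiSSA config file and wrap a HuggingFace model with it,
# mirroring the pattern used by the training scripts in this PR.
import torch
from coati.models import LoraConfig, convert_to_lora_module
from transformers import AutoModelForCausalLM

lora_config = LoraConfig.from_file("/PATH/TO/THE/LORA/CONFIG/FILE.json")
model = AutoModelForCausalLM.from_pretrained("/PATH/TO/PRETRAINED/MODEL")  # placeholder path
model = convert_to_lora_module(model, lora_config=lora_config)

# As in the training scripts, keep norm/gate modules in fp32 for numerical stability.
for name, module in model.named_modules():
    if "norm" in name or "gate" in name:
        module = module.to(torch.float32)
```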
</details>
@ -445,7 +469,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
@ -470,9 +494,15 @@ In this code we provide a flexible way for users to set the conversation templat
- Step 1: (Optional). Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields.
```json
{
"chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating,
"system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added,
"end_of_assistant": The token(s) in string that denotes the end of assistance's response. For example, in the ChatGLM2 prompt format,
"chat_template": "A string of chat_template used for formatting chat data",
"system_message": "A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added",
"end_of_assistant": "The token(s) in string that denotes the end of assistance's response",
"stop_ids": "A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training"
}
```
* `chat_template`: (Optional) A string of chat_template used for formatting chat data. If not set (None), the default chat template of the provided tokenizer will be used. If a path to a huggingface model or local model is provided, the chat_template of that model will be used. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating.
* `system_message`: A string of system message to be added at the beginning of the prompt. If none is provided (None), no system message will be added.
* `end_of_assistant`: The token(s) in string that denote the end of the assistant's response. For example, in the ChatGLM2 prompt format,
```
<|im_start|>system
system messages
@ -481,15 +511,13 @@ In this code we provide a flexible way for users to set the conversation templat
<|im_start|>user
How far is the moon? <|im_end|>
<|im_start|>assistant\n The moon is about 384,400 kilometers away from Earth.<|im_end|>...
```
the end_of_assistant tokens are "<|im_end|>"
"stop_ids": (Optional), A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically
}
```
On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message),
```
the `end_of_assistant` tokens are "<|im_end|>"
* `stop_ids`: (Optional) A list of integers corresponding to the `end_of_assistant` tokens that indicate the end of the assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, it will be set to tokenizer.eos_token_ids automatically.
On your first run of the data preparation script, you only need to define the `chat_template` (if you want to use a custom chat template) and the `system_message` (if you want to use a custom system message).
- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
- Step 3: (Optional) Check the correctness of the processed data. We provided an easy way for you to do a manual checking on the processed data by checking the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files.
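As a quick, optional sanity check of a custom `chat_template` before running the data preparation script, the sketch below renders a short conversation with the standard `transformers` `apply_chat_template` API. The config path and tokenizer path are placeholders; the config layout follows the conversation template files added in this PR.
```python
# Sketch: render a conversation with a custom chat_template to eyeball the output.
import json

from transformers import AutoTokenizer

# Placeholder path to a conversation template config like the ones in this PR.
with open("/PATH/TO/CHAT/TEMPLATE/CONFIG.json") as f:
    template_config = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("/PATH/TO/PRETRAINED/TOKENIZER")  # placeholder
messages = [
    {"role": "system", "content": template_config["system_message"]},
    {"role": "user", "content": "How far is the moon?"},
]
rendered = tokenizer.apply_chat_template(
    messages,
    chat_template=template_config["chat_template"],
    tokenize=False,
    add_generation_prompt=True,
)
print(rendered)
```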
@ -509,7 +537,7 @@ Human: <s> what are some pranks with a pen i can do?</s> Assistant: <s> Are you
#### Step 3: Training
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./training_scripts/train_sft.sh) to start supervised instruction fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
### RLHF Training Stage2 - Training Reward Model
@ -526,7 +554,7 @@ Below shows the preference dataset format used in training the reward model.
[
{"context": [
{
"from": "human",
"from": "user",
"content": "Introduce butterflies species in Oregon."
}
]
@ -551,11 +579,11 @@ Below shows the preference dataset format used in training the reward model.
#### Step 2: Preprocessing
Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
#### Step 3: Training
You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
You can run [train_rm.sh](./training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
#### Features and Tricks in RM Training
@ -595,7 +623,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
#### Step 1: Data Collection
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory; data samples within the prompt dataset end with a line from the "user", and thus the "assistant" needs to generate a response to answer the "user". Note that you can still use a conversation that ends with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
@ -603,7 +631,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
}
...
@ -626,14 +654,14 @@ The second dataset--- pretrained dataset is optional, provide it if you want to
]
```
#### Step 2: Preprocessing
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./data_preparation_scripts/prepare_prompt_dataset.sh)
You can use the SFT dataset you prepared in the SFT stage or prepare a new one from a different source for the ptx dataset. The ptx data is used to calculate the ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
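As a rough illustration of how the ptx term enters the objective (following the InstructGPT formulation; `ptx_coef` and the two loss tensors are assumptions for illustration, not the exact coati API):
```python
# Sketch: mix a pretraining (ptx) language-modeling loss into the PPO actor loss.
import torch


def actor_loss_with_ptx(policy_loss: torch.Tensor,
                        ptx_lm_loss: torch.Tensor,
                        ptx_coef: float = 0.1) -> torch.Tensor:
    # The ptx term adds pretraining-distribution gradients (InstructGPT's gamma
    # coefficient), which stabilizes PPO training on the SFT/pretrain data.
    return policy_loss + ptx_coef * ptx_lm_loss
```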
#### Step 3: Training
You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
You can run the [train_ppo.sh](./training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO; please refer to the [training configuration](#training-configuration) section for details regarding other supported training options.
```bash
@ -717,17 +745,90 @@ For DPO training, you only need the preference dataset. Please follow the instru
#### Step 2: Training
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
You can run the [train_dpo.sh](./training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added options for the user to choose from, including whether to apply length normalization, whether to use reward shaping, and whether to use a reference model when calculating the implicit reward. Here are those options:
```
--beta 0.1 \ # the temperature in DPO loss, Default to 0.1
--gamma 0.0 \ # the reward target margin in the SimPO paper, Default to 0.
--disable_reference_model \ # whether to disable the reference model, if set, the implicit reward will be calculated solely from the actor. Default to enable reference model in DPO
--length_normalization \ # whether to apply length normalization, Default to not use
```
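Below is a minimal, self-contained sketch of how these options enter the preference loss (an illustration under the stated assumptions, not the exact coati loss code). With length normalization, a positive `gamma`, and the reference model disabled, it reduces to the SimPO objective described in the next section.
```python
# Sketch of a DPO-style preference loss with the options above: beta, gamma,
# length normalization, and an optional reference model. Inputs are per-sample
# summed response log-probs of shape (batch,) and token counts of shape (batch,).
import torch.nn.functional as F


def preference_loss(logp_chosen, logp_rejected,                    # policy log-probs
                    ref_logp_chosen=None, ref_logp_rejected=None,  # reference log-probs (optional)
                    len_chosen=None, len_rejected=None,            # response token counts
                    beta=0.1, gamma=0.0, length_normalization=False):
    if length_normalization:
        logp_chosen, logp_rejected = logp_chosen / len_chosen, logp_rejected / len_rejected
        if ref_logp_chosen is not None:
            ref_logp_chosen = ref_logp_chosen / len_chosen
            ref_logp_rejected = ref_logp_rejected / len_rejected
    if ref_logp_chosen is not None:
        # standard DPO implicit reward against the reference model
        margin = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    else:
        # --disable_reference_model: reward comes solely from the actor (SimPO-style)
        margin = logp_chosen - logp_rejected
    # gamma is the reward target margin from SimPO; gamma=0 recovers plain DPO
    return -F.logsigmoid(beta * margin - gamma).mean()
```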
#### DPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/DPO.png">
</p>
### Alternative Option For RLHF: Simple Preference Optimization
We support the method introduced in the paper [SimPO: Simple Preference Optimization
with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), which is a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO for alignment, use the [train_dpo.sh](./training_scripts/train_dpo.sh) script and set `loss_type` to `simpo_loss`; you can also set the values for the temperature (`beta`) and the reward target margin (`gamma`), but this is optional.
#### SimPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/SimPO_margin.png">
</p>
### Alternative Option For RLHF: Odds Ratio Preference Optimization
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), which is a reference-model-free alignment method that mixes the SFT loss with a reinforcement-learning-style loss using the odds ratio as the implicit reward, enhancing training stability and efficiency. To use ORPO for alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script. You can set the value for `lambda` (which determines how strongly the reinforcement learning loss affects the training), but this is optional.
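A minimal sketch of the ORPO objective under the stated assumptions (per-response average log-probabilities as inputs; this is not the exact `OddsRatioLoss` used by the trainer):
```python
# Sketch of the ORPO loss: SFT (NLL) loss on the chosen response plus a
# lambda-weighted odds-ratio term. avg_logp_* are per-token average log-probs
# of the chosen/rejected responses, shape (batch,).
import torch
import torch.nn.functional as F


def orpo_loss(chosen_nll, avg_logp_chosen, avg_logp_rejected, lam=0.1):
    # odds(y) = p / (1 - p), with p = exp(average log-prob of the response)
    log_odds_chosen = avg_logp_chosen - torch.log1p(-torch.exp(avg_logp_chosen))
    log_odds_rejected = avg_logp_rejected - torch.log1p(-torch.exp(avg_logp_rejected))
    log_odds_ratio = log_odds_chosen - log_odds_rejected
    # the odds-ratio term (implicit reward) pushes chosen odds above rejected odds
    return chosen_nll - lam * F.logsigmoid(log_odds_ratio).mean()
```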
#### ORPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ORPO_margin.png">
</p>
### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), which is an alignment method that directly maximizes the "human utility" of generation results.
For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examples/data_preparation_scripts/prepare_kto_dataset.sh). You will need preference data; unlike DPO and its derivatives, you no longer need a pair of chosen/rejected responses for the same input. You only need data whose response is associated with a preference label--- whether the response is okay or not; read the paper for more details. You also need to convert your data to the following intermediate format before you run the data preparation script.
```jsonl
{
"prompt": [
{
"from": "user",
"content": "What are some praise words in english?"
},
{
"from": "assistant",
"content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."
},
{
"from": "user",
"content": "What's your favorite one?"
}
],
"completion": {
"from": "assistant",
"content": "impressive."
},
"label": true
}
```
For training, use the [train_kto.sh](./examples/training_scripts/train_kto.sh) script. You may need to set the value for `beta` (the temperature in the KTO loss), as well as `desirable_weight` and `undesirable_weight` if your data is biased (has an unequal number of chosen and rejected samples).
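A simplified sketch of the KTO objective under the stated assumptions (`kl_term` stands in for the paper's batch-level KL reference point and is assumed to be precomputed; this is illustrative, not the exact `KTOTrainer` loss):
```python
# Sketch of the KTO loss over a batch of (response, label) pairs. policy_logp and
# ref_logp are summed response log-probs under the policy and reference model,
# labels is a bool/int tensor (True = desirable), all of shape (batch,).
import torch


def kto_loss(policy_logp, ref_logp, labels, kl_term,
             beta=0.1, desirable_weight=1.0, undesirable_weight=1.0):
    rewards = policy_logp - ref_logp  # implicit reward: log pi(y|x) / pi_ref(y|x)
    losses = torch.where(
        labels.bool(),
        desirable_weight * (1 - torch.sigmoid(beta * (rewards - kl_term))),
        undesirable_weight * (1 - torch.sigmoid(beta * (kl_term - rewards))),
    )
    return losses.mean()
```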
#### KTO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
</p>
## Hardware Requirements
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use an H800 GPU with 80GB VRAM.
For SFT, we recommend using zero2 or zero2-cpu for a 7B model, and tp if your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
- 2 H800 GPU
- zero2-cpu, micro batch size=4, VRAM Usage=22457.98 MB
- zero2, micro batch size=4, VRAM Usage=72390.95 MB
- 4 H800 GPUs
- zero2_cpu, micro batch size=8, VRAM Usage=19412.77 MB
- zero2, micro batch size=8, VRAM Usage=43446.31 MB
- zero2, micro batch size=16, VRAM Usage=58082.30 MB
- zero2, micro batch size=8, lora_rank=8, VRAM Usage=21167.73 MB
- zero2, micro batch size=8, lora_rank=32, VRAM Usage=21344.17 MB
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model (llama2-7B-hf) on a dummy dataset with a sequence length of 2048 and a layout length of 512 with different tp_size (equal to the number of GPUs).
| PPO | tp=8 | tp=4 |
|-------|---------------|---------------|
| bs=1 | 18485.19 MB | 42934.45 MB |
@ -738,12 +839,39 @@ For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM
For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
- 1 H800 GPU
- zero2-cpu, batch size=2, VRAM Usage=49873.90 MB
- zero2-cpu, batch size=4, VRAM Usage=60998.22 MB
- 2 H800 GPU
- zero2-cpu, micro batch size=2, VRAM Usage=36989.37 MB
- zero2-cpu, micro batch size=4, VRAM Usage=48081.67 MB
- 4 H800 GPUs
- zero2, batch size=4, VRAM Usage=67544.47 MB
- zero2, micro batch size=4, VRAM Usage=67483.44 MB
For SimPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
- 2 H800 GPU
- zero2-cpu, micro batch size=4, VRAM 25705.26 MB
- zero2, micro batch size=4, VRAM Usage=73375.04 MB
- 4 H800 GPUs
- zero2_cpu, micro batch size=8, VRAM Usage=36709.36 MB
- zero2, micro batch size=4, VRAM Usage=44330.90 MB
- zero2, micro batch size=8, VRAM Usage=56086.12 MB
For ORPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
- 2 H800 GPU
- zero2-cpu, micro batch size=4, VRAM 26693.38 MB
- zero2, micro batch size=4, VRAM Usage=74332.65 MB
- 4 H800 GPUs
- zero2_cpu, micro batch size=8, VRAM Usage=38709.73 MB
- zero2, micro batch size=4, VRAM Usage=45309.52 MB
- zero2, micro batch size=8, VRAM Usage=58086.37 MB
For KTO, we recommend using the zero2-cpu or zero2 plugin. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
- 2 H800 GPU
- zero2-cpu, micro batch size=2, VRAM Usage=35241.98 MB
- zero2-cpu, micro batch size=4, VRAM Usage=38989.37 MB
- 4 H800 GPUs
- zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB
- zero2, micro batch size=4, VRAM_USAGE=59307.97 MB
## List of Supported Models

View File

@ -40,7 +40,7 @@ import random
import time
from multiprocessing import cpu_count
from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer
@ -56,8 +56,8 @@ def main():
type=str,
required=True,
default=None,
choices=["sft", "prompt", "preference"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'.",
choices=["sft", "prompt", "preference", "kto"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
)
parser.add_argument(
"--data_input_dirs",
@ -199,11 +199,13 @@ def main():
)
if args.type == "sft":
preparation_function = supervised_tokenize_sft
preparation_function = tokenize_sft
elif args.type == "prompt":
preparation_function = tokenize_prompt_dataset
preparation_function = tokenize_prompt
elif args.type == "preference":
preparation_function = tokenize_rlhf
elif args.type == "kto":
preparation_function = tokenize_kto
else:
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
@ -228,10 +230,13 @@ def main():
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
)
dataset = dataset.filter(
lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
)
if args.type == "kto":
filter_by = "completion"
elif args.type == "preference":
filter_by = "chosen_input_ids"
else:
filter_by = "input_ids"
dataset = dataset.filter(lambda data: data[filter_by] is not None)
# Save each jsonl spliced dataset.
output_index = "0" * (5 - len(str(index))) + str(index)

View File

@ -0,0 +1,14 @@
SAVE_DIR=""
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type kto \
--data_input_dirs /PATH/TO/KTO/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024

View File

@ -5,9 +5,10 @@ rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type preference \
--data_input_dirs "PATH/TO/PREFERENCE/DATA" \
--data_input_dirs /PATH/TO/PREFERENCE/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024

View File

@ -10,4 +10,5 @@ python prepare_dataset.py --type prompt \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024

View File

@ -5,9 +5,10 @@ rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type sft \
--data_input_dirs "PATH/TO/SFT/DATA" \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 4096

View File

@ -0,0 +1,104 @@
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, lets hear a story. </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, lets hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>
==========
==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, lets hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>
==========
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>
==========
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story youd like to hear? </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story youd like to hear? </s><s>[INST] about Donald Trump [/INST] Id be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question youd like to ask about Donald Trump? </s>
==========
==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story youd like to hear? </s><s>[INST] about Donald Trump [/INST] Id be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question youd like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>
==========

View File

@ -1,4 +1,4 @@
pandas>=1.4.1
sentencepiece
colossalai
colossalai==0.4.0
prompt_toolkit

View File

@ -1,5 +1 @@
XXX.XX.XXX.XXX # Your master IP
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
localhost

View File

@ -0,0 +1,9 @@
{
"r": 128,
"embedding_lora_dropout": 0.0,
"linear_lora_dropout": 0.1,
"lora_alpha": 32,
"lora_train_bias": "all",
"lora_initialization_method": "PiSSA",
"target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
}

View File

@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import DPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -23,8 +23,11 @@ logger = get_dist_logger()
def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -115,8 +118,8 @@ def train(args):
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(model)
if args.enable_reference_model:
if not args.disable_reference_model:
if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
@ -125,18 +128,20 @@ def train(args):
)
else:
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(ref_model)
else:
ref_model = None
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
disable_dropout(model)
disable_dropout(ref_model)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint and args.lora_rank == 0:
model.gradient_checkpointing_enable()
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@ -178,6 +183,21 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
@ -255,25 +275,28 @@ def train(args):
save_interval=args.save_interval,
save_dir=args.save_dir,
coordinator=coordinator,
beta=args.beta,
gamma=args.gamma,
length_normalization=args.length_normalization,
)
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=None,
eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -296,6 +319,10 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
parser.add_argument("--length_normalization", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
@ -304,33 +331,39 @@ if __name__ == "__main__":
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--save_dir", type=str, default="output")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--enable_reference_model", type=bool, default=True)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
# fool proof hyperparameter setup
if args.loss_type == "simpo_loss":
args.length_normalization = True
args.gamma = args.gamma if args.gamma > 0 else 1.4
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -13,50 +13,52 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
# export CUDA_VISIBLE_DEVICES=6
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="dpo"
PROJECT_NAME="DPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
YOUR/DATA/DIR/arrow/part-00000
YOUR/DATA/DIR/arrow/part-00001
YOUR/DATA/DIR/arrow/part-00002
YOUR/DATA/DIR/arrow/part-00003
YOUR/DATA/DIR/arrow/part-00004
YOUR/DATA/DIR/arrow/part-00005
YOUR/DATA/DIR/arrow/part-00006
YOUR/DATA/DIR/arrow/part-00007
YOUR/DATA/DIR/arrow/part-00008
YOUR/DATA/DIR/arrow/part-00009
/Your/Preference/Data/arrow/part-00000
/Your/Preference/Data/arrow/part-00001
/Your/Preference/Data/arrow/part-00002
/Your/Preference/Data/arrow/part-00003
/Your/Preference/Data/arrow/part-00004
/Your/Preference/Data/arrow/part-00005
/Your/Preference/Data/arrow/part-00006
/Your/Preference/Data/arrow/part-00007
/Your/Preference/Data/arrow/part-00008
/Your/Preference/Data/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_dpo.py \
colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--checkpoint_path $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 1 \
--accumulation_steps 4 \
--batch_size 2 \
--accumulation_steps 2 \
--batch_size 16 \
--lr 1e-6 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 4096 \
--weight_decay 0.01 \
--warmup_steps 100 \
--warmup_steps 60 \
--grad_checkpoint \
--use_wandb

View File

@ -0,0 +1,376 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
logger = get_dist_logger()
def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for debugging purpose
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
else:
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
disable_dropout(ref_model)
disable_dropout(model)
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers don't allow setting the pad_token manually, e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
# configure optimizer
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
num_desirable = 0
num_undesirable = 0
for i in range(len(train_dataset)):
if train_dataset[i]["label"]:
num_desirable += 1
else:
num_undesirable += 1
logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}")
# Check if the user-specified weights fit into the theoretical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
if actual_ratio < 1 or actual_ratio > 4 / 3:
if not args.auto_weight:
raise AssertionError(
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase/decrease desirable weight or decrease/increase undesirable weight."
)
else:
args.desirable_weight = args.desirable_weight / actual_ratio
coordinator.print_on_master(
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}"
)
data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
train_dataloader = plugin.prepare_dataloader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optim,
total_steps=args.max_epochs * num_update_steps_per_epoch,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
if ref_model is not None:
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
sampler_start_idx = 0
start_step = 0
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
booster.load_model(model, args.checkpoint_path)
else:
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
)
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
trainer = KTOTrainer(
actor=model,
ref_model=ref_model,
booster=booster,
actor_optim=optim,
actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
start_epoch=start_epoch,
save_interval=args.save_interval,
save_dir=args.save_dir,
coordinator=coordinator,
beta=args.beta,
desirable_weight=args.desirable_weight,
undesirable_weight=args.undesirable_weight,
)
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--auto_weight", default=False, action="store_true")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -0,0 +1,65 @@
#!/bin/bash
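# Select the n GPUs with the least memory currently in use and expose only those via CUDA_VISIBLE_DEVICES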
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
/Your/KTO/Data/arrow/part-00000
/Your/KTO/Data/arrow/part-00001
/Your/KTO/Data/arrow/part-00002
/Your/KTO/Data/arrow/part-00003
/Your/KTO/Data/arrow/part-00004
/Your/KTO/Data/arrow/part-00005
/Your/KTO/Data/arrow/part-00006
/Your/KTO/Data/arrow/part-00007
/Your/KTO/Data/arrow/part-00008
/Your/KTO/Data/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
--auto_weight \
--lr 1e-5 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 1024 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint

View File

@ -0,0 +1,341 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import ORPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
logger = get_dist_logger()
def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
"""
Default torch DDP plugin without any acceleration, for debugging purposes
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=args.zero_stage >= 1 and args.zero_cpu_offload,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
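# Cast normalization and gate modules back to fp32 for numerical stability under fp16/bf16 training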
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
disable_dropout(model)
if args.grad_checkpoint:
# Note: for some models, LoRA may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers do not allow setting pad_token manually, e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
# configure optimizer
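# HybridAdam (ColossalAI) can update parameters on either GPU or CPU, so it also works with the CPU-offloading plugins above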
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
train_dataloader = plugin.prepare_dataloader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optim,
total_steps=args.max_epochs * num_update_steps_per_epoch,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
sampler_start_idx = 0
start_step = 0
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
booster.load_model(model, args.checkpoint_path)
else:
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
)
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
trainer = ORPOTrainer(
actor=model,
booster=booster,
actor_optim=optim,
actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
start_epoch=start_epoch,
save_interval=args.save_interval,
save_dir=args.save_dir,
coordinator=coordinator,
lam=args.lam,
)
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -0,0 +1,64 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="ORPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
/Your/Preference/Data/arrow/part-00000
/Your/Preference/Data/arrow/part-00001
/Your/Preference/Data/arrow/part-00002
/Your/Preference/Data/arrow/part-00003
/Your/Preference/Data/arrow/part-00004
/Your/Preference/Data/arrow/part-00005
/Your/Preference/Data/arrow/part-00006
/Your/Preference/Data/arrow/part-00007
/Your/Preference/Data/arrow/part-00008
/Your/Preference/Data/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 3 \
--accumulation_steps 1 \
--batch_size 16 \
--lr 8e-6 \
--lam 0.5 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 1024 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
--use_wandb

View File

@ -13,7 +13,7 @@ from coati.dataset import (
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -31,8 +31,11 @@ logger = get_dist_logger()
def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -81,20 +84,26 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
reward_model = RewardModel(args.rm_pretrain)
critic = Critic(args.rm_pretrain)
if args.lora_config is not None:
actor = convert_to_lora_module(actor, lora_config=lora_config)
critic = convert_to_lora_module(critic, lora_config=lora_config)
for name, module in actor.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
for name, module in critic.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
lora_manager.able_to_merge = False
# Disable dropout
disable_dropout(actor)
disable_dropout(critic)
if args.lora_rank > 0:
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint and args.lora_rank == 0:
actor.gradient_checkpointing_enable()
critic.model.gradient_checkpointing_enable()
if args.grad_checkpoint:
actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@ -421,11 +430,9 @@ def train(args):
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
lora_manager.able_to_merge = True
actor.eval()
critic.eval()
# save model checkpoint after fitting on only rank0
@ -484,11 +491,9 @@ if __name__ == "__main__":
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--experience_batch_size", type=int, default=16)
parser.add_argument("--ptx_batch_size", type=int, default=4)
parser.add_argument("--lora_train_bias", type=str, default="none")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1)

View File

@ -15,10 +15,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
PROJECT_NAME="ppo"
PROJECT_NAME="PPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@ -54,7 +53,7 @@ declare -a ptx_dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \
--pretrain $PRETRAINED_MODEL_PATH \

View File

@ -7,7 +7,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module
from coati.trainer import RewardModelTrainer
from coati.utils import load_checkpoint
from transformers import AutoTokenizer
@ -16,14 +16,20 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -55,9 +61,11 @@ def train(args):
args.pretrain,
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
# ==============================
# Initialize Booster
# ==============================
@ -119,11 +127,9 @@ def train(args):
booster = Booster(plugin=plugin)
if args.grad_checkpoint and args.lora_rank == 0:
model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer
if args.grad_checkpoint:
model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@ -173,6 +179,22 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch)
@ -253,21 +275,21 @@ def train(args):
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=None,
eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -297,33 +319,28 @@ if __name__ == "__main__":
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--save_dir", type=str, default="output")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -15,10 +15,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
PROJECT_NAME="rm"
PROJECT_NAME="RM"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@ -38,17 +38,18 @@ declare -a dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
--pretrain $PRETRAINED_MODEL_PATH \
--checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 3 \
--accumulation_steps 1 \
--batch_size 8 \

View File

@ -7,7 +7,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module
from coati.models import LoraConfig, convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -24,8 +24,11 @@ logger = get_dist_logger()
def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -53,15 +56,19 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
if args.plugin == "ddp":
"""
Default torch DDP plugin without any acceleration, for debugging purposes
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
@ -114,6 +121,15 @@ def train(args):
booster = Booster(plugin=plugin)
# configure optimizer
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
@ -122,12 +138,10 @@ def train(args):
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
if args.grad_checkpoint and args.lora_rank == 0:
# lora layers are not supported by gradient checkpointing
model.gradient_checkpointing_enable()
if args.grad_checkpoint:
# Note: for some models, LoRA may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(
@ -151,15 +165,6 @@ def train(args):
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
# configure optimizer
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# configure dataset
coordinator.print_on_master(
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@ -175,6 +180,23 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
@ -202,6 +224,7 @@ def train(args):
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@ -257,22 +280,21 @@ def train(args):
trainer.fit(
train_dataloader=train_dataloader,
eval_dataloader=None,
eval_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
# booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
if args.save_path is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -302,32 +324,27 @@ if __name__ == "__main__":
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--save_path", type=str, default="output")
parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -13,13 +13,11 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
# export CUDA_VISIBLE_DEVICES=4,5,6
set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="sft"
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="SFT"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
@ -38,28 +36,25 @@ declare -a dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
echo $(which colossalai)
echo $(which python)
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--save_interval 4000 \
--save_interval 2000 \
--dataset ${dataset[@]} \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--lora_rank 0 \
--plugin 3d \
--tp 2 \
--pp 1 \
--zero_stage 0 \
--batch_size 2 \
--max_epochs 3 \
--plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
--accumulation_steps 1 \
--lr 5e-5 \
--max_len 400 \
--max_len 4096 \
--use_flash_attn \
--grad_checkpoint \
--use_wandb \
--use_flash_attn
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \

View File

@ -1,9 +1,9 @@
transformers>=4.36.2
transformers==4.39.3
tqdm
datasets==2.14.7
loralib
colossalai>=0.3.7
torch>=1.12.1
colossalai==0.4.0
torch>=2.1.0
langchain
tokenizers
fastapi

View File

@ -4,7 +4,7 @@ import os
sft_seed = {
"messages": [
{"from": "human", "content": "Give three tips for staying healthy."},
{"from": "user", "content": "Give three tips for staying healthy."},
{
"from": "assistant",
"content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.",
@ -13,7 +13,7 @@ sft_seed = {
}
prompt_seed = {
"messages": [
{"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."},
{"from": "user", "content": "Describe the impacts of climate change on communities living in coastal areas."},
{
"from": "assistant",
"content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.",
@ -22,21 +22,34 @@ prompt_seed = {
}
preference_seed = {
"context": [
{"from": "human", "content": "What kind of noises did dinosaurs make?"},
{"from": "user", "content": "What kind of noises did dinosaurs make?"},
{
"from": "assistant",
"content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be",
},
{"from": "human", "content": "yes they did"},
{"from": "user", "content": "yes they did"},
{
"from": "assistant",
"content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.",
},
{"from": "human", "content": "you cant read"},
{"from": "user", "content": "you cant read"},
],
"chosen": [{"from": "assistant", "content": "You can read?"}],
"rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
}
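# A KTO sample pairs a prompt with a single completion and a boolean label marking it desirable (True) or undesirable (False)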
kto_seed = {
"prompt": [
{"from": "user", "content": "What are some praise words in english?"},
{
"from": "assistant",
"content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ...",
},
{"from": "user", "content": "What's your favorite one?"},
],
"completion": {"from": "assistant", "content": "Impressive."},
"label": True,
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@ -61,12 +74,21 @@ if __name__ == "__main__":
seed = prompt_seed
elif args.data_type == "preference":
seed = preference_seed
elif args.data_type == "kto":
seed = kto_seed
else:
raise ValueError(f"Unknown data type {args.data_type}")
line = json.dumps(seed, ensure_ascii=False) + "\n"
for idx in [1, 2, 3]:
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
for i in range(1000):
if args.data_type != "kto":
line = json.dumps(seed, ensure_ascii=False) + "\n"
for idx in [1, 2, 3]:
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
for i in range(1000):
f.write(line)
f.write(line)
f.write(line)
else:
for idx in [1, 2, 3]:
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
for i in range(1000):
seed["label"] = not seed["label"]
line = json.dumps(seed, ensure_ascii=False) + "\n"
f.write(line)

View File

@ -1 +1 @@
{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
{"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}

View File

@ -0,0 +1 @@
{"prompt": [{"from": "user", "content": "What are some praise words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "impressive."},"label": true}

View File

@ -1 +1 @@
{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
{"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}

View File

@ -71,6 +71,8 @@ get_data_input_dirs() {
echo "$PROMPT_DATASET"
elif [[ $data_type == "preference" ]]; then
echo "$PREFERENCE_DATASET"
elif [[ $data_type == "kto" ]]; then
echo "$KTO_DATASET"
else
echo "Unknown data type $data_type"
exit 1
@ -121,6 +123,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt) \
--data_type "prompt"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs kto) \
--data_type "kto"
echo "[Test]: testing prepare_preference_dataset.py ..."
# FIXME: This is a hack to skip tests that are not working
@ -258,3 +264,50 @@ for model in ${MODELS[@]}; do
exit 1
fi
done
echo "[Test]: testing prepare_kto_dataset.py ..."
# FIXME: This is a hack to skip tests that are not working
SKIPPED_TESTS=(
)
# test prepare_kto_dataset
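# Each model retries dataset preparation up to NUM_RETRY times; the cache/jsonl/arrow dirs are wiped before every attempt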
for model in ${MODELS[@]}; do
data_type="kto"
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
echo "[Test]: Skipped $model-$data_type"
continue
fi
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
data_input_dirs=$(get_data_input_dirs $data_type)
tokenizer_dir=$(get_tokenizer_dirs $model)
conversation_template=$(get_conversation_template_config $model)
for i in $(seq $NUM_RETRY); do
rm -rf $cache_dir
rm -rf $jsonl_dir
rm -rf $arrow_dir
echo "[Test]: $model-$data_type, attempt $i"
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
--type kto \
--data_input_dirs $data_input_dirs \
--conversation_template_config $conversation_template \
--tokenizer_dir $tokenizer_dir \
--data_cache_dir $cache_dir \
--data_jsonl_output_dir $jsonl_dir \
--data_arrow_output_dir $arrow_dir \
--max_length 400 \
--num_samples_per_datafile 100 \
--num_spliced_dataset_bins 1
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$data_type"
exit 1
fi
done

View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from coati.models import convert_to_lora_module
from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear
from torch.utils.data import DataLoader, TensorDataset
@ -38,7 +39,7 @@ def test_overfit():
# Build and convert model
model = SimpleNN(input_size, hidden_size, num_classes)
weight_to_compare = model.fc1.weight.detach().clone()
model = convert_to_lora_module(model, lora_rank=30)
model = convert_to_lora_module(model, lora_config=LoraConfig(r=32))
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
@ -50,7 +51,6 @@ def test_overfit():
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
print(loss)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
@ -65,5 +65,50 @@ def test_overfit():
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
def test_lora_linear_accuracy():
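# A freshly initialized LoRA linear layer (default or PiSSA initialization) should reproduce the base layer's output before any training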
weight = torch.randn(10, 5)
linear = nn.Linear(5, 10)
linear.weight.data = weight
x = torch.randn(10, 5)
out_linear = linear(x)
# lora linear Pissa
linear.weight.data = weight
lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA")
out_lora = lora_linear(x)
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
# lora linear
linear.weight.data = weight
lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
out_lora = lora_linear(x)
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
def test_lora_embedding_accuracy():
weight = torch.randn(10, 5)
embedding = nn.Embedding(10, 5)
embedding.weight.data = weight
x = torch.randint(0, 10, (10,))
out_embedding = embedding(x)
# lora embedding Pissa
embedding.weight.data = weight
lora_embedding = LoraEmbedding(
embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5
)
out_lora = lora_embedding(x)
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
# lora embedding
embedding.weight.data = weight
lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5)
out_lora = lora_embedding(x)
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
if __name__ == "__main__":
test_overfit()
test_lora_linear_accuracy()
test_lora_embedding_accuracy()
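
These accuracy tests rely on a freshly initialized LoRA layer being a no-op: in the standard formulation W' = W + B·A, B starts at zero so the low-rank branch contributes nothing, and PiSSA instead refactors W into a principal low-rank part plus a residual, so the reconstructed weight still equals the original at initialization. A minimal sketch of that property with toy classes (this is the generic LoRA idea, not coati's LoraLinear/PiSSA implementation):

```python
import torch
import torch.nn as nn


class ToyLoraLinear(nn.Module):
    """Toy LoRA wrapper: y = base(x) + x @ A^T @ B^T, with B zero-initialized."""

    def __init__(self, base: nn.Linear, r: int = 2):
        super().__init__()
        self.base = base
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init => B @ A == 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + x @ self.lora_A.t() @ self.lora_B.t()


base = nn.Linear(5, 10)
x = torch.randn(10, 5)
# At initialization the LoRA branch is exactly zero, so outputs match the base layer.
assert torch.allclose(base(x), ToyLoraLinear(base)(x), atol=1e-5)
```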

View File

@ -94,7 +94,7 @@ done
# Test DPO/PPO data Preparation
for model in ${MODELS[@]}; do
echo "Testing DPO/PPO data templating for $model"
echo "Testing DPO/RM data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating for $model"
echo "[Test]: Failed in the DPO/RM data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating test for $model"
echo "[Test]: Failed in the DPO/RM data templating test for $model"
exit 1
fi
done
# Test KTO data Preparation
for model in ${MODELS[@]}; do
echo "Testing KTO data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/kto/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model)
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \
--tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating test for $model"
exit 1
fi
done

View File

@ -30,9 +30,10 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout
MODELS=('llama')
ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
export OMP_NUM_THREADS=8
@ -112,6 +113,11 @@ for lora_rank in ${LORA_RANK[@]}; do
sp='1'
sp_mode='split_gather'
enable_sequence_parallelism=''
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
@ -173,9 +179,10 @@ for lora_rank in ${LORA_RANK[@]}; do
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
@ -192,8 +199,8 @@ for lora_rank in ${LORA_RANK[@]}; do
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
@ -229,6 +236,11 @@ for lora_rank in ${LORA_RANK[@]}; do
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
@ -248,9 +260,10 @@ for lora_rank in ${LORA_RANK[@]}; do
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
@ -262,8 +275,8 @@ for lora_rank in ${LORA_RANK[@]}; do
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
@ -306,6 +319,11 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='4'
ebs='8'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='4'
bs='16'
@ -342,7 +360,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
--lora_rank $lora_rank \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
@ -361,8 +379,8 @@ for lora_rank in ${LORA_RANK[@]}; do
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
@ -402,6 +420,11 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='4'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2'
# gemini_auto and gemini don't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
@ -423,9 +446,10 @@ for lora_rank in ${LORA_RANK[@]}; do
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
@ -437,8 +461,176 @@ for lora_rank in ${LORA_RANK[@]}; do
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing ORPO ..."
SKIPPED_TESTS=(
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2'
# gemini_auto and gemini don't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto doesn't support generation
# (need to calculate ref_model logits through forwarding in inference mode)
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a dataset=()
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing KTO ..."
SKIPPED_TESTS=(
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2'
# gemini_auto and gemini don't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto doesn't support generation
# (need to calculate ref_model logits through forwarding in inference mode)
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a dataset=()
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--lr 2e-5 \
--auto_weight \
--desirable_weight 1.2 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done

View File

@ -62,3 +62,11 @@ if __name__ == "__main__":
assert any(
[rejected_lable in s for s in to_verify_lable_rejected]
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
elif args.data_type == "kto":
sample = data[0]
to_verify_data = to_verify_data[0]
for line in sample["prompt"]:
assert line["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["completion_decode"]
assert sample["label"] == to_verify_data["label"]

View File

@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
files = glob.glob(os.path.join(path, "*.jsonl"))

View File

@ -1,6 +1,9 @@
from abc import abstractstaticmethod
from colossal_eval.utils import jdump
from torch.utils.data import Dataset
from colossalai.logging import DistributedLogger
class BaseDataset:
@ -12,13 +15,24 @@ class BaseDataset:
logger: Logger for the dataset.
"""
def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
def __init__(self, path, logger, *args, **kwargs):
self.dataset = self.load(path, logger, *args, **kwargs)
def save(self, save_path):
"""Save the converted dataset"""
jdump(self.dataset, save_path)
@abstractstaticmethod
def load(path, logger):
def load(path, logger: DistributedLogger, *args, **kwargs):
"""Load the original dataset and convert it into the inference dataset"""
class DistributedDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
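
DistributedDataset is just a list-backed torch Dataset; it lets the evaluation driver shard each category across data-parallel ranks with a stock DistributedSampler instead of slicing the list by hand. A small self-contained sketch of that usage (rank and replica count are hard-coded here instead of coming from the process-group mesh, and collate_fn keeps the raw dicts, mirroring the change in inference.py later in this PR):

```python
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class ListDataset(Dataset):
    """Same idea as DistributedDataset above: wrap a plain Python list."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


samples = [{"id": i, "question": f"q{i}"} for i in range(10)]
dataset = ListDataset(samples)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=4, sampler=sampler, collate_fn=lambda x: x)
for batch in loader:
    print([s["id"] for s in batch])  # this rank only sees its shard of the data
```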

View File

@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))

View File

@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))

View File

@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
data = jload(path)
data_per_category = get_data_per_category(data)

View File

@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl")
data_list = []

View File

@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
files = os.listdir(os.path.join(path, "data", category))

View File

@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
files = os.listdir(path)

View File

@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset):
"""
@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))

View File

@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset):
This dataset class will convert the original dataset into the inference dataset.
"""
def __init__(self, path, logger, few_shot):
def __init__(self, path, logger: DistributedLogger, *args, **kwargs):
self.multiturn = True
self.dataset = self.load(path, logger, few_shot)
self.dataset = self.load(path, logger, *args, **kwargs)
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": defaultdict(dict)}
file_path = os.path.join(path, "question.jsonl")

View File

@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files:

View File

@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset):
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files:

View File

@ -1,11 +1,11 @@
import copy
import math
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
@ -130,7 +130,7 @@ class HuggingFaceModel(BaseModel):
if shard_config is not None:
self.model = AutoModel.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
self.model, _ = shard_former.optimize(self.model)
self.model.to(get_current_device())
if peft_path is not None:
@ -325,7 +325,7 @@ class HuggingFaceModel(BaseModel):
return input_ids_list, labels_list, None
def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
"""
Infer the given data.
This function will call self.generate() to get model outputs and also self.model() to get logits.
@ -359,26 +359,23 @@ class HuggingFaceModel(BaseModel):
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
turn_desc = "" if turn == 0 else f"-turn{turn}"
bar = tqdm(
range(math.ceil(len(data) / self.batch_size)),
desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
range(len(data_loader)),
desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
disable=not is_rank_0(),
)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
answers = copy.deepcopy(data)
for i in range(0, len(data), self.batch_size):
batch = data[i : i + self.batch_size]
answers = []
for i, batch in enumerate(data_loader):
batch_prompt, batch_target = get_batch_prompt(
self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
)
if is_rank_0() and debug and i == 0:
self.logger.info(
f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}"
f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
)
self.logger.info("-" * 120)
self.logger.info("An example prompt and prompt with target is:")
@ -402,7 +399,7 @@ class HuggingFaceModel(BaseModel):
# Otherwise this will violate the single-choice setting.
if calculate_loss:
labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]
labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]
loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
@ -411,29 +408,30 @@ class HuggingFaceModel(BaseModel):
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
]
for j in range(len(batch_prompt)):
for j in range(len(batch)):
if not pretrain:
if isinstance(answers[i + j]["output"], list):
answers[i + j]["output"].append(batch_decodes[j].strip())
if isinstance(batch[j]["output"], list):
batch[j]["output"].append(batch_decodes[j].strip())
else:
answers[i + j]["output"] = batch_decodes[j].strip()
batch[j]["output"] = batch_decodes[j].strip()
if isinstance(scores, torch.Tensor):
answers[i + j]["logits_over_choices"] = probs[j]
batch[j]["logits_over_choices"] = probs[j]
if calculate_loss:
answers[i + j]["loss_over_choices"] = loss_over_choices[j]
batch[j]["loss_over_choices"] = loss_over_choices[j]
if calculate_loss:
answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
# loss_sum is specially used for the pretrain dataset for calculating per-byte perplexity.
# However, loss (which is per sample loss) suffices for most cases.
answers[i + j]["loss_sum"] = batch_losses[j]
answers[i + j]["token_num"] = batch_target_token_nums[j]
batch[j]["loss_sum"] = batch_losses[j]
batch[j]["token_num"] = batch_target_token_nums[j]
if batch_bytes_nums:
answers[i + j]["byte_num"] = batch_bytes_nums[j]
batch[j]["byte_num"] = batch_bytes_nums[j]
answers.extend(batch)
bar.update()
@ -600,7 +598,7 @@ class HuggingFaceCausalLM(HuggingFaceModel):
if shard_config is not None:
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
self.model, _ = shard_former.optimize(self.model)
self.model.to(get_current_device())
if peft_path is not None:
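
As a side note on the single-choice scoring used earlier in this file: loss_over_choices is simply a per-sample cross-entropy between the logits restricted to the answer choices and the gold choice index. A tiny standalone sketch of that computation (the shapes and choice set here are made up for illustration):

```python
import torch

choices = ["A", "B", "C", "D"]
str_label_map = {c: i for i, c in enumerate(choices)}

# scores: per-sample logits over the four answer choices (batch of 3 here).
scores = torch.randn(3, len(choices))
targets = ["B", "D", "A"]
labels = torch.tensor([str_label_map[t] for t in targets], dtype=torch.long)

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss_over_choices = loss_fct(scores, labels)          # one loss value per sample
probs = torch.nn.functional.softmax(scores, dim=-1)   # per-choice probabilities
print(loss_over_choices.tolist(), probs.shape)
```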

View File

@ -123,15 +123,13 @@ class Conversation:
}
def get_few_shot_prefix(
conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
) -> str:
def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int) -> str:
"""
Get few shot prefix.
Args:
conv: Conversation template.
few_shot_examples: Few shot examples to generate few shot prompt prefix.
few_shot_data: Few shot examples to generate few shot prompt prefix.
tokenizer: tokenizer used to tokenize data.
Returns:
Few shot prompt prefix.
@ -157,7 +155,6 @@ def get_batch_prompt(
batch: List[Dict],
few_shot_data: List[str],
tokenizer: Optional[AutoTokenizer],
language: Optional[str],
model_max_length: Optional[int],
) -> Tuple[List[Dict], List[Dict]]:
"""
@ -167,6 +164,7 @@ def get_batch_prompt(
conv: Conversation template.
batch: Batch data to generate prompt from.
few_shot_data: Few shot data to generate few shot prompt prefix.
tokenizer: tokenizer used to tokenize data.
Returns:
Tuple containing batch prompt and target.
@ -192,7 +190,7 @@ def get_batch_prompt(
else:
raise Exception("When using few-shot, target answer should be a string.")
few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
few_shot_prefix = get_few_shot_prefix(few_shot_data, tokenizer, max_tokens)
conv.append_message(conv.roles[0], few_shot_prefix + query_text)
conv.append_message(conv.roles[1], None)
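
The new get_few_shot_prefix drops the conversation template and language arguments and only needs the examples, a tokenizer and a token budget. A hedged sketch of what such a helper typically does (an illustration of the signature only, not the actual implementation):

```python
from typing import List, Optional


def build_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[object], max_tokens: int) -> str:
    """Illustrative only: concatenate examples until the token budget is exhausted."""
    prefix, used = "", 0
    for example in few_shot_data:
        # Fall back to a whitespace count when no tokenizer is supplied.
        n = len(tokenizer(example)["input_ids"]) if tokenizer is not None else len(example.split())
        if used + n > max_tokens:
            break
        prefix += example + "\n"
        used += n
    return prefix
```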

View File

@ -5,6 +5,8 @@ from typing import Dict, List
import torch.distributed as dist
from colossal_eval import dataset, models, utils
from colossal_eval.dataset.base import DistributedDataset
from torch.utils.data import DataLoader, DistributedSampler
import colossalai
from colossalai.accelerator import get_accelerator
@ -13,6 +15,7 @@ from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig
logger = get_dist_logger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def rm_and_merge(
@ -54,7 +57,8 @@ def rm_and_merge(
)
else:
rank_answers = utils.jload(directory)
answers["data"].extend(rank_answers["data"])
deduplicated_answers = [x for x in rank_answers["data"] if x not in answers["data"]]
answers["data"].extend(deduplicated_answers)
answers["inference_kwargs"] = rank_answers["inference_kwargs"]
for r in range(dp_size):
@ -65,7 +69,7 @@ def rm_and_merge(
os.remove(directory)
except Exception as e:
print(e)
print(len(answers["data"]))
all_answers[category] = answers
all_answers_with_dataset_class["inference_results"] = all_answers
@ -108,7 +112,12 @@ def main(args):
tp_rank = coordinates[TP_AXIS]
shard_config = (
ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1)
ShardConfig(
tensor_parallel_process_group=tp_group,
enable_tensor_parallelism=args.tp_size > 1,
parallel_output=False,
enable_all_optimization=True,
)
if args.tp_size > 1
else None
)
@ -183,6 +192,7 @@ def main(args):
model_name = model_parameter["name"]
model_class = eval(f"models.{model_parameter['model_class']}")
paramerters = model_parameter["parameters"]
batch_size = paramerters["batch_size"]
paramerters.update({"logger": logger})
paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]})
paramerters.update({"shard_config": shard_config})
@ -192,7 +202,6 @@ def main(args):
raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")
for dataset_name, split_data in inference_data.items():
start = 0
prev_questions = None
for category, category_data in split_data.items():
num_turn = category_data["inference_kwargs"].get("turns", 1)
@ -201,26 +210,33 @@ def main(args):
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
answers_to_dump = copy.deepcopy(category_data)
partition_size = len(category_data["data"]) // dp_size
redundant = len(category_data["data"]) % dp_size
# Ensure that the amount of data for inference is as consistent as possible across different processes.
lengths = [partition_size for _ in range(dp_size)]
for j in range(redundant):
lengths[(j + start) % dp_size] += 1
start = (start + redundant) % dp_size
for turn in range(num_turn):
if turn == 0:
questions = category_data["data"][
sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank]
]
dist_dataset = DistributedDataset(category_data["data"])
else:
questions = prev_questions
dist_dataset = DistributedDataset(prev_questions)
sampler = DistributedSampler(
dist_dataset,
num_replicas=pg_mesh.size(DP_AXIS),
rank=pg_mesh.coordinate(DP_AXIS),
shuffle=False,
)
questions_loader = DataLoader(
dist_dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=8,
pin_memory=True,
collate_fn=lambda x: x,
)
category_data["inference_kwargs"]["dataset"] = dataset_name
category_data["inference_kwargs"]["category"] = category
answers_per_rank = model_.inference(
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
data_loader=questions_loader,
inference_kwargs=category_data["inference_kwargs"],
debug=debug_args[dataset_name],
)
prev_questions = answers_per_rank
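
One consequence of switching to DistributedSampler: when the number of questions is not divisible by the data-parallel size, the sampler pads by repeating samples so every rank receives the same number of items, which is exactly why rm_and_merge earlier in this file now drops duplicate answers when merging per-rank result files. A quick illustration of the padding behaviour:

```python
from torch.utils.data import DistributedSampler

data = list(range(5))  # 5 questions, 2 data-parallel ranks
rank0 = list(DistributedSampler(data, num_replicas=2, rank=0, shuffle=False))
rank1 = list(DistributedSampler(data, num_replicas=2, rank=1, shuffle=False))
print(rank0, rank1)  # [0, 2, 4] and [1, 3, 0]: index 0 is processed on both ranks
```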

View File

@ -2,7 +2,7 @@ import ctypes
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from types import MethodType
@ -30,11 +30,15 @@ from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from .pp_plugin_base import PipelinePluginBase
@ -61,6 +65,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
overlap_allgather: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@ -69,6 +74,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.sp_group = sp_group
self.use_dpp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@ -106,6 +112,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module)
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)
def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
@ -197,7 +209,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
with self._wait_all_gather():
return super().forward(*args, **kwargs)
def unwrap(self):
module = super().unwrap()
@ -205,6 +218,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
module = module.module
return module
def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)
def _wait_all_gather(self):
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@ -235,7 +255,7 @@ def get_param_info(optim: Optimizer):
return param_info
def init_pipeline_optimizer(optim: Optimizer, model: Module):
def reinitialize_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
@ -257,7 +277,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
reinitialize_optimizer(optim, model)
self.model = model
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@ -478,7 +498,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
if use_pipeline:
init_pipeline_optimizer(optim, model)
reinitialize_optimizer(optim, model)
super().__init__(
optim,
precision=precision,
@ -632,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@ -650,6 +671,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
self.model = model
self.param_info = param_info
@ -658,11 +680,12 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
reinitialize_optimizer(optimizer, model)
super().__init__(
optimizer=optimizer,
initial_scale=initial_scale,
min_scale=min_scale,
pg_to_param_list=pg_to_param_list,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
@ -677,6 +700,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
overlap_allgather=overlap_allgather,
)
def sync_dp_grads(self):
@ -993,9 +1017,11 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@ -1038,17 +1064,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
(
self.dp_axis,
self.pp_axis,
self.tp_axis,
self.sp_axis,
) = (
0,
1,
2,
3,
)
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
@ -1148,6 +1164,7 @@ class HybridParallelPlugin(PipelinePluginBase):
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
)
self.max_norm = max_norm
@ -1176,7 +1193,7 @@ class HybridParallelPlugin(PipelinePluginBase):
return True
def support_lora(self) -> bool:
return False
return True
def control_checkpoint_io(self) -> bool:
return True
@ -1210,6 +1227,7 @@ class HybridParallelPlugin(PipelinePluginBase):
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
@ -1224,6 +1242,7 @@ class HybridParallelPlugin(PipelinePluginBase):
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:
@ -1306,7 +1325,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# so we disable it, performing manual reduction instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
with ctx, model._wait_all_gather():
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
@ -1362,15 +1381,15 @@ class HybridParallelPlugin(PipelinePluginBase):
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(
dataset,
num_replicas=self.pg_mesh.size(self.dp_axis),
rank=self.pg_mesh.coordinate(self.dp_axis),
num_replicas=self.dp_group.size(),
rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()),
shuffle=shuffle,
)
@ -1402,6 +1421,24 @@ class HybridParallelPlugin(PipelinePluginBase):
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
def enable_lora(
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
self,
model: Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> Module:
raise NotImplementedError
from peft import PeftModel, get_peft_model
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
assert self.pp_size == 1 and self.tp_size == 1
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
return peft_model
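
With support_lora now returning True, the intended flow looks roughly like the sketch below, within the constraints asserted above (pp_size == 1 and tp_size == 1). This is a hypothetical usage sketch, meant to run inside a launched distributed job, and it assumes peft's standard LoraConfig rather than any coati-specific config.

```python
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=1)
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any causal LM works
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# enable_lora must be called before boosting (see the assertion above).
model = plugin.enable_lora(model, lora_config=LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16))
model, optimizer, *_ = booster.boost(model, optimizer)
```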

View File

@ -2,6 +2,7 @@ import enum
import logging
import os
import warnings
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from types import MethodType
@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
with ctx:
return super().forward(*args, **kwargs)
def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()
@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
overlap_allgather: bool = False,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
@ -316,6 +357,7 @@ class LowLevelZeroPlugin(DPPluginBase):
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)
self.lora_enabled = False
@ -406,7 +448,7 @@ class LowLevelZeroPlugin(DPPluginBase):
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn(
"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@ -433,7 +475,9 @@ class LowLevelZeroPlugin(DPPluginBase):
self.add_lora_params_to_optimizer(model, optimizer)
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
model = LowLevelZeroModel(
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
)
# TODO: Support Galore + ZeRO
zero_stage = self.stage
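
The overlap_allgather flag threaded through LowLevelZeroModel (and HybridParallelModule earlier) defers waiting on the ZeRO parameter all-gather until the forward pass actually touches each parameter, instead of blocking up front. A rough, framework-agnostic sketch of the underlying pattern using plain torch.distributed async collectives (this is not colossalai's ZeroOpHook/ColoParamOpHookManager machinery, just the idea it implements):

```python
import torch
import torch.distributed as dist


def launch_all_gather(shard: torch.Tensor, world_size: int):
    """Kick off an async all-gather and return (buffers, handle) without waiting."""
    buffers = [torch.empty_like(shard) for _ in range(world_size)]
    handle = dist.all_gather(buffers, shard, async_op=True)
    return buffers, handle


def materialize(buffers, handle) -> torch.Tensor:
    """Called lazily, right before the full parameter is needed in forward."""
    handle.wait()  # overlap ends here; communication ran concurrently with earlier compute
    return torch.cat(buffers)
```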

View File

@ -1,9 +1,8 @@
import random
import warnings
from collections import defaultdict
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
from typing import Callable, List, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@ -11,34 +10,42 @@ from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.booster.plugin.hybrid_parallel_plugin import (
PRECISION_TORCH_TYPE,
SUPPORT_SP_MODE,
HybridParallelAMPOptimizer,
HybridParallelModule,
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
HybridParallelZeroOptimizer,
get_param_info,
init_pipeline_optimizer,
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
from colossalai.shardformer.shard.shard_config import ShardConfig
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
dp_process_group: Optional[ProcessGroup], # the dp pg for comm
tp_process_group: Optional[ProcessGroup], # if using tp
pp_process_group: Optional[ProcessGroup], # if using pp
moe_dp_group: ProcessGroup, # moe dp pg for comm
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
@ -51,37 +58,25 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
overlap_allgather: bool = False,
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: [],
moe_extra_dp_process_group: [],
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
for param in model.parameters():
if is_moe_tensor(param):
pg_param_list[moe_extra_dp_process_group].append(param)
else:
pg_param_list[dp_process_group].append(param)
if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
super().__init__(
model=model,
optimizer=optimizer,
pg_to_param_list=pg_param_list,
use_pipeline=use_pipeline,
param_info=param_info,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@ -96,30 +91,37 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
tp_process_group=tp_process_group,
pp_process_group=pp_process_group,
forced_dtype=forced_dtype,
pg_to_param_list=pg_param_list,
overlap_allgather=overlap_allgather,
)
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
Plugin for Moe Hybrid Parallel Training.
Plugin for MoE Hybrid Parallel Training, which is similar to HybridParallelPlugin.
Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp should be passed in by the user; the size of dp is then automatically computed as dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import HybridParallelPlugin
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
model, train_dataset, optimizer, criterion = ...
plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2, ep_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```
Args:
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
ep_size (int): The size of expert parallelism
sp_size (int): The size of sequence parallelism.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
Defaults to 'fp16'.
@ -132,7 +134,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
@ -155,15 +159,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
"""
def __init__(
self,
tp_size: int,
pp_size: int,
ep_size: int,
tp_size: int = 1,
sp_size: int = 1,
sp_size: int = None,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
@ -171,7 +181,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
@ -191,27 +203,61 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
use_ep_inside: bool = True,
overlap_communication: bool = False,
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpointIO] = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
moe_dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
) -> None:
world_size = dist.get_world_size()
assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism is not supported in MoE yet"
if overlap_communication or zero_stage == 2:
overlap_communication = False
zero_stage = 1
warnings.warn(
f"overlap_communication and zero_stage are set to False and 1 because "
f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
)
assert (
world_size % (tp_size * pp_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
world_size % (tp_size * pp_size * ep_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
)
assert (
self.sequence_parallelism_mode in SUPPORT_SP_MODE
), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
if self.sequence_parallelism_mode in ["split_gather", "ring"]:
assert (
tp_size > 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
if sp_size != 1:
warnings.warn(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
elif self.sequence_parallelism_mode in ["all_to_all"]:
self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert (
sp_size == 1 or sp_size is None
), f"You should not set sp_size when sequence parallelism is not enabled."
self.sp_size = 1
self.dp_size = world_size // (tp_size * pp_size)
assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}"
self.moe_dp_size = self.dp_size // ep_size
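# Worked example (illustrative): with world_size=16, tp_size=2, pp_size=2, ep_size=2 and
# sequence parallelism disabled, dp_size = 16 // (2 * 2) = 4 and moe_dp_size = 4 // 2 = 2,
# i.e. each expert's parameters are still replicated across 2 MoE data-parallel ranks.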
self.ep_size = ep_size
self.tp_size = tp_size
self.pp_size = pp_size
self.ep_size = ep_size
self.sp_size = sp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@ -220,61 +266,69 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
logger = get_dist_logger()
# NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
# See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
# we change pg mesh to (pp, dp, tp) for better moe performance
assert (
self.ep_size <= self.dp_size
), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
self.moe_dp_size = self.dp_size // self.ep_size
self.use_ep_inside = use_ep_inside
if self.use_ep_inside:
logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
if moe_dp_outside:
self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size)
else:
logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
logger.info(
f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
)
self.tp_group = self.pg_mesh.get_group_along_axis(
self.tp_axis
) # TODO: support custom tp size for mixtral lm head
self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.custom_policy = custom_policy
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
else:
raise NotImplementedError()
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis])
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
@ -282,8 +336,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
ep_group=self.ep_group,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
initial_scale=initial_scale,
@ -310,77 +367,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
)
self.max_norm = max_norm
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset,
num_replicas=self.dp_size,
rank=dist.get_rank(self.global_dp_group),
shuffle=shuffle,
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpointIO(
self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)
else:
self.checkpoint_io = self.checkpoint_io(
self.global_dp_group,
self.pp_group,
self.tp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
zero_stage=self.zero_stage,
)
if hasattr(self.checkpoint_io, "moe_info"):
self.checkpoint_io.moe_info = self.moe_info
return self.checkpoint_io
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)
def configure(
self,
@ -391,13 +387,40 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
# TODO: Support Galore + ZeRO
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
)
self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.global_dp_group,
dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@ -405,7 +428,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:
# if ep is enabled, the number of (moe) parameters changes since they are sharded among ep groups
# but the optimizer is not aware of ep, so we need to update the optimizer
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@ -418,20 +447,30 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
tp_process_group=self.tp_group,
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
if self.dp_size <= 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.global_dp_group,
dp_process_group=dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_extra_dp_process_group=self.moe_dp_group,
moe_dp_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
@ -440,4 +479,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
# Setup optimizers that require global states
optim = optimizer.optim
if isinstance(optim, DistributedOptim):
shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
return model, optimizer, criterion, dataloader, lr_scheduler
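For orientation, a minimal usage sketch of the plugin configured above (not part of this diff). It assumes a torchrun launch and that the placeholder `model`, `optimizer`, and `dataloader` wrap an MoE architecture supported by shardformer (e.g. Mixtral or DeepSeek-MoE); all values are illustrative:

```python
# Hypothetical usage sketch of MoeHybridParallelPlugin; placeholders are marked with "...".
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # e.g. torchrun --nproc_per_node=4 train.py

model = ...       # a supported MoE model, e.g. a Mixtral/DeepSeek-MoE causal LM (placeholder)
optimizer = ...   # e.g. torch.optim.Adam(model.parameters(), lr=1e-5) (placeholder)
dataloader = ...  # e.g. built via plugin.prepare_dataloader(dataset, batch_size=...) (placeholder)

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,        # dp_size is derived from the world size and must be divisible by ep_size
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
```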


@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
if os.path.isfile(checkpoint):
@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
This argument should be manually set to False since params on same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()
@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
if self.dp_rank != 0:
@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
strict = False
model_before_wrapping = model
model = model.unwrap()
@ -943,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_[k] = v.detach().clone().to(device)
return state_
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
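Downstream, this hook is reached through the booster's checkpoint IO. A hedged sketch, assuming the Booster-level `enable_lora` / `save_lora_as_pretrained` API used by other plugins and continuing from a boosted setup like the one sketched earlier; the path is illustrative:

```python
# Sketch only: save LoRA adapters of a boosted model (assumes LoRA was enabled via peft).
from peft import LoraConfig

lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)
# ... train ...
booster.save_lora_as_pretrained(model, "./lora_checkpoint", use_safetensors=True)
```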


@ -151,13 +151,10 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
# ep_rank 0 saves all the parameters and buffers.
# other ep_ranks save only experts
ep_param_pattern = "experts." if self.ep_rank != 0 else None
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = MoECheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
)
state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0


@ -44,7 +44,7 @@ class DistCoordinator(metaclass=SingletonMeta):
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
self._local_rank = os.environ.get("LOCAL_RANK", -1)
self._local_rank = int(os.environ.get("LOCAL_RANK", -1))
@property
def rank(self) -> int:
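The added int() cast matters because launchers export LOCAL_RANK as a string; a tiny illustration of the comparison bug the change avoids (values are hypothetical):

```python
import os

os.environ["LOCAL_RANK"] = "0"                 # torchrun sets this as a string
local_rank = os.environ.get("LOCAL_RANK", -1)
print(local_rank == 0)                         # False: "0" != 0
print(int(local_rank) == 0)                    # True with the cast used in DistCoordinator
```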


@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import GroupMember
def prod(nums: List[int]) -> int:
@ -47,7 +48,7 @@ class ProcessGroupMesh:
self._shape = size
self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember.NON_GROUP_MEMBER]] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
def destroy_mesh_process_groups(self):
@ -136,7 +137,7 @@ class ProcessGroupMesh:
assert mode in ["raise", "wrap", "clip"]
return int(np.ravel_multi_index(coord, shape, mode))
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
Args:
@ -147,10 +148,11 @@ class ProcessGroupMesh:
ProcessGroup: The process group with the given ranks.
"""
ranks_in_group = sorted(ranks_in_group)
if tuple(ranks_in_group) not in self._group_to_ranks:
if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group
self._group_to_ranks[group] = tuple(ranks_in_group)
if group is not GroupMember.NON_GROUP_MEMBER:
self._group_to_ranks[group] = tuple(ranks_in_group)
return self._ranks_to_group[tuple(ranks_in_group)]
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
@ -238,7 +240,7 @@ class ProcessGroupMesh:
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
group = self.get_group(ranks_in_group, backend=backend)
group = self._get_group(ranks_in_group, backend=backend)
if self._rank in ranks_in_group:
target_group = group
return target_group
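Since get_group is now private, the supported path is get_group_along_axis / create_group_along_axis. A small sketch under an assumed 4-process torchrun launch (axis sizes are illustrative):

```python
# Sketch: consuming ProcessGroupMesh through its public API (assumes 4 processes, NCCL backend).
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

dist.init_process_group(backend="nccl")
DP_AXIS, TP_AXIS = 0, 1
mesh = ProcessGroupMesh(2, 2)                    # 2-way data parallel x 2-way tensor parallel
tp_group = mesh.get_group_along_axis(TP_AXIS)    # groups are created lazily via _get_group
dp_group = mesh.get_group_along_axis(DP_AXIS)
print(dist.get_rank(tp_group), dist.get_world_size(tp_group))
```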


@ -18,7 +18,7 @@
## 📌 Introduction
ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>
@ -310,4 +310,14 @@ If you wish to cite relevant research papers, you can find the reference below.
journal={arXiv},
year={2023}
}
# Distrifusion
@InProceedings{Li_2024_CVPR,
author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month={June},
year={2024},
pages={7183-7193}
}
```


@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers.generation import GenerationConfig
@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM; the relevant algorithm is described in the paper at https://arxiv.org/pdf/2309.17453.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, when using StreamingLLM.
patched_parallelism_size(int): The patched parallelism size, when using Distrifusion.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512
# Acceleration for Diffusion Model(PipeFusion or Distrifusion)
patched_parallelism_size: int = 1 # for distrifusion
# pipeFusion_m_size: int = 1 # for pipefusion
# pipeFusion_n_size: int = 1 # for pipefusion
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
# check Distrifusion
# TODO(@lry89757) need more detailed check
if self.patched_parallelism_size > 1:
# self.use_patched_parallelism = True
self.tp_size = (
self.patched_parallelism_size
)  # this is not real tensor parallelism; some existing checks rely on tp_size, so it has to be set to patched_parallelism_size
# check prompt template
if self.prompt_template is None:
return
@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config
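A hedged sketch of turning the feature on; field values are illustrative and, as the check above shows, tp_size is overridden internally:

```python
# Sketch: enabling Distrifusion for diffusion inference on 2 GPUs (illustrative values).
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    dtype="fp16",
    patched_parallelism_size=2,   # image patches are split across 2 devices
)
print(inference_config.tp_size)   # 2: forced to patched_parallelism_size in _verify_config
```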
@ -396,3 +411,50 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
patched_parallelism_size: int = 1 # for diffusion model, Distrifusion Technique
@dataclass
class DiffusionGenerationConfig:
"""
Param for diffusion model forward
"""
prompt_2: Optional[Union[str, List[str]]] = None
prompt_3: Optional[Union[str, List[str]]] = None
height: Optional[int] = None
width: Optional[int] = None
num_inference_steps: int = None
timesteps: List[int] = None
guidance_scale: float = None
negative_prompt: Optional[Union[str, List[str]]] = (
None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
)
negative_prompt_2: Optional[Union[str, List[str]]] = None
negative_prompt_3: Optional[Union[str, List[str]]] = None
num_images_per_prompt: Optional[int] = None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
latents: Optional[torch.FloatTensor] = None
prompt_embeds: Optional[torch.FloatTensor] = None
negative_prompt_embeds: Optional[torch.FloatTensor] = None
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
output_type: Optional[str] = None # "pil"
return_dict: bool = None
joint_attention_kwargs: Optional[Dict[str, Any]] = None
clip_skip: Optional[int] = None
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
callback_on_step_end_tensor_inputs: List[str] = None
def to_dict(self) -> Dict[str, Any]:
# NOTE(@lry89757) Only include fields whose value is not the default None
result = {}
for field in fields(self):
value = getattr(self, field.name)
if value is not None:
result[field.name] = value
return result
@classmethod
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
return cls(**kwargs)
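Because to_dict drops unset (None) fields, only explicitly provided options reach the underlying diffusers pipeline; a quick sketch with illustrative values:

```python
# Sketch: only explicitly set fields survive to_dict(); everything else falls back to
# the diffusers pipeline defaults.
from colossalai.inference.config import DiffusionGenerationConfig

gen_cfg = DiffusionGenerationConfig.from_kwargs(height=1024, width=1024, num_inference_steps=30)
print(gen_cfg.to_dict())  # {'height': 1024, 'width': 1024, 'num_inference_steps': 30}
```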


@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
class BaseEngine(ABC):
@abstractmethod
def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
pass
@abstractmethod
def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
"""
Init Model for Engine
"""
@abstractmethod
def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
"""
Generate output for incoming requests
"""
@abstractmethod
def add_request(self, prompts, request_ids=None, **kwargs):
"""
Add new request to Engine
"""
@abstractmethod
def step(self):
"""
Perform one new step forward
"""
@abstractmethod
def _verify_args(self):
"""
Verify the parameters and members of class
"""
@torch.inference_mode()
def capture_model(self):
"""
Use cuda graph to capture model
"""
raise NotImplementedError("This method should be implemented by subclasses")
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
**kwargs,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): Path or nn.Module of this model.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by Shardformer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
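To make the contract concrete, here is a toy engine that merely satisfies the abstract interface above (it is not part of this diff); a real engine would additionally shard its model via self._shardformer and schedule requests through a request handler, as the subclasses below do:

```python
# Toy sketch of the BaseEngine contract; all behavior here is illustrative only.
class EchoEngine(BaseEngine):
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        self.inference_config = inference_config
        self.requests = []
        self._verify_args()

    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        pass  # nothing to shard in this toy engine

    def add_request(self, prompts, request_ids=None, **kwargs):
        self.requests.extend(prompts if isinstance(prompts, list) else [prompts])

    def step(self):
        return [self.requests.pop(0)] if self.requests else []

    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        if prompts is not None:
            self.add_request(prompts)
        outputs = []
        while self.requests:
            outputs += self.step()
        return outputs

    def _verify_args(self):
        pass
```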


@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler
PP_AXIS, TP_AXIS = 0, 1
class DiffusionEngine(BaseEngine):
def __init__(
self,
model_or_path: DiffusionPipeline | str,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Policy | type[Policy] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.model_type = get_model_type(model_or_path=model_or_path)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.request_handler = NaiveRequestHandler()
self.counter = count()
self._verify_args()
def _verify_args(self) -> None:
assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"
def init_model(
self,
model_or_path: Union[str, nn.Module, DiffusionPipeline],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard the model and/or load weights
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
if isinstance(model_or_path, str):
model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
policy_map_key = model.__class__.__name__
model = DiffusionPipe(model)
elif isinstance(model_or_path, DiffusionPipeline):
policy_map_key = model_or_path.__class__.__name__
model = DiffusionPipe(model_or_path)
else:
self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
model_policy = model_policy_map.get(policy_map_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = model.to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
generation_config: DiffusionGenerationConfig = None,
**kwargs,
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
""" """
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
**gen_config_dict,
**kwargs,
)
output_reqs_list = []
# intuition: if the user provides a generation config, replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
while self.request_handler.check_unfinished_reqs():
output_reqs_list += self.step()
return output_reqs_list
def add_request(
self,
prompts: Union[List[str], str],
request_ids: Union[List[int], int] = None,
**kwargs,
):
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if not isinstance(prompts, list):
prompts = [prompts]
generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
prompts_num = len(prompts)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)
seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)
self.request_handler.add_sequence(seq)
def step(self) -> List[PIL.Image.Image]:
"""
In each step, do the following:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run forward to get List[Image].
Returns:
List[PIL.Image.Image]: Image Generated by one step.
"""
input = self.request_handler.schedule()
ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
return ret
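End to end, the diffusion path can be driven through the unified InferenceEngine as sketched below; the checkpoint id and prompt are illustrative, a torchrun launch is assumed, and the diffusers weights must be available locally:

```python
# Sketch: image generation through the unified InferenceEngine (illustrative model id/prompt).
import torch
from diffusers import DiffusionPipeline

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()  # e.g. torchrun --nproc_per_node=1 infer.py

pipe = DiffusionPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
engine = InferenceEngine(pipe, inference_config=InferenceConfig(dtype="fp16"))
images = engine.generate(prompts="an astronaut riding a horse")  # List[PIL.Image.Image] per step()
images[0].save("astronaut.png")
```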


@ -1,57 +1,24 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import List, Tuple, Type, Union
import numpy as np
import torch
import PIL.Image
import torch.nn as nn
from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from diffusers import DiffusionPipeline
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.utils import ModelType, get_model_type
from colossalai.shardformer.policies.base_policy import Policy
from .request_handler import RequestHandler
__all__ = ["InferenceEngine"]
PP_AXIS, TP_AXIS = 0, 1
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine:
"""
InferenceEngine which manages the inference process.
Args:
model_or_path (nn.Module or str): Path or nn.Module of this model.
model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
@ -60,567 +27,68 @@ class InferenceEngine:
def __init__(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
model_or_path: Union[nn.Module, str, DiffusionPipeline],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__
self.model_type = get_model_type(model_or_path=model_or_path)
self.engine = None
if self.model_type == ModelType.LLM:
from .llm_engine import LLMEngine
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.engine = LLMEngine(
model_or_path=model_or_path,
tokenizer=tokenizer,
inference_config=inference_config,
verbose=verbose,
model_policy=model_policy,
)
elif self.model_type == ModelType.DIFFUSION_MODEL:
from .diffusion_engine import DiffusionEngine
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.engine = DiffusionEngine(
model_or_path=model_or_path,
inference_config=inference_config,
verbose=verbose,
model_policy=model_policy,
)
elif self.model_type == ModelType.UNKNOWN:
self.logger.error(f"Model Type either Difffusion or LLM!")
self._initialized = True
self._verify_args()
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
pretrained_path = None
if isinstance(model_or_path, str):
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
if arch is "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _supported_models[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
model = model.to(self.dtype).eval()
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
block_size = self.inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
)  # NOTE: this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen equal to max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE: this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen equal to max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
input_meta_data = InputMetaData(
block_tables=block_tables[:batch_size],
sequence_lengths=sequence_lengths[:batch_size],
fd_inter_tensor=fd_inter_tensor,
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids[:batch_size],
output_tensor[:batch_size],
input_meta_data,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): Path or nn.Module of this model.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by Shardformer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.use_spec_dec:
self.disable_spec_dec()
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
"""
Run speculative decoding steps. This retrieves a single batch and launches inference
with many steps of speculation by a drafter model as well as verification by the main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_token_ids = batch.get_1D_inputs_spec_dec(1)
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
batch = self.request_handler.running_bb # running batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cache[-1],  # use kv caches of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
n_spec_tokens=self.n_spec_tokens,
)
drafter_out = self.drafter.speculate(
input_token_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_token_ids = batch.get_1D_inputs_spec_dec(n)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
# Reset back the number of speculated tokens of the batch,
# this is used to handle the last round of speculation, in which case the number of speculated tokens
# by the drafter is less than the number of speculated tokens set to the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
assert self.engine is not None, "Please init Engine first"
assert self._initialized, "Engine must be initialized"
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
*args,
**kwargs,
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
for seq in output_seqs_list:
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
if return_token_ids:
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
return output_str, output_tokens_list
else:
return output_str
@property
def has_prompt_template(self) -> bool:
""" """
return self.inference_config.prompt_template is not None
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
"""
This method will format the input prompt according to the prompt template given to the InferenceConfig.
"""
assert (
self.has_prompt_template
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
@ -630,168 +98,36 @@ class InferenceEngine:
assert self.engine is not None, "Please init Engine first"
return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
def add_request(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
*args,
**kwargs,
) -> None:
"""
Add requests.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
kwargs: for LLM, it could be max_length, max_new_tokens, etc.
for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers
"""
assert self.engine is not None, "Please init Engine first"
self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
def step(self):
assert self.engine is not None, "Please init Engine first"
return self.engine.step()
def __getattr__(self, name):
"""
The design logic of __getattr__ and __setattr__:
1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we want to invoke every member of DiffusionEngine/LLMEngine as if it were a member of InferenceEngine.
2. In the __init__ of InferenceEngine, we don't want to setattr via self.__dict__["xxx"] = xxx; we want to use the usual form self.xxx = xxx.
So we set the attribute `_initialized`. After initialization, if a member cannot be found on InferenceEngine, we try to get it from self.engine (DiffusionEngine/LLMEngine).
"""
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
return self.__dict__[name]
else:
return getattr(self.engine, name)
else:
return self.__dict__[name]
def __setattr__(self, name, value):
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
self.__dict__[name] = value
else:
setattr(self.engine, name, value)
else:
self.__dict__[name] = value
# (The lines removed by this hunk -- the old in-place implementations of add_request, prepare_input
# and step -- are identical to the corresponding methods of the new LLMEngine file shown below.)
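# Illustrative sketch (not part of the diff): thanks to the delegation above, members that only
# exist on the wrapped LLMEngine/DiffusionEngine can be reached directly through the wrapper,
# e.g. assuming `engine` is an initialized InferenceEngine:
#
#     engine.generate(prompts=["Hello"])   # defined on the wrapper, forwarded to self.engine
#     engine.request_handler               # not on the wrapper -> resolved via __getattr__
#     engine.use_spec_dec = False          # after init -> forwarded to self.engine via __setattr__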

View File

@ -0,0 +1,758 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .base_engine import BaseEngine
from .request_handler import RequestHandler
PP_AXIS, TP_AXIS = 0, 1
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class LLMEngine(BaseEngine):
"""
LLMEngine which manages the LLM inference process.
Args:
model_or_path (nn.Module or str): Path or nn.Module of this model.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast], optional): The tokenizer to use for inference.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
model_policy (Policy, optional): the policy used to shard the model. It will be determined by the model type if not provided.
"""
def __init__(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self._verify_args()
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard the model and/or load its weights.
Args:
model_or_path (Union[nn.Module, str]): path to the checkpoint, or a model in transformers format.
model_policy (Policy): the policy used to shard the model.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for initializing modules at inference time.
"""
pretrained_path = None
if isinstance(model_or_path, str):
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
if arch == "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _supported_models[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
model = model.to(self.dtype).eval()
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please enable use_cuda_graph in the InferenceConfig"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
block_size = self.inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
) # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen equal to max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen equal to max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
input_meta_data = InputMetaData(
block_tables=block_tables[:batch_size],
sequence_lengths=sequence_lengths[:batch_size],
fd_inter_tensor=fd_inter_tensor,
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids[:batch_size],
output_tensor[:batch_size],
input_meta_data,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {model.__class__.__name__} is not supported."
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.use_spec_dec:
self.disable_spec_dec()
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
"""
Run speculative decoding steps: retrieve a single batch and run inference with multiple
rounds of speculation by the drafter model, each round verified by the main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_token_ids = batch.get_1D_inputs_spec_dec(1)
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
batch = self.request_handler.running_bb # running batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cache[-1], # use kv caches of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
n_spec_tokens=self.n_spec_tokens,
)
drafter_out = self.drafter.speculate(
input_token_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_token_ids = batch.get_1D_inputs_spec_dec(n)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
# Reset the number of speculated tokens for the batch. This handles the last round of speculation,
# in which the drafter may have speculated fewer tokens than the number configured on the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
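# Illustrative example (not part of the file) of the accept/revoke rule used above: if the drafter
# speculated [5, 8, 3, 9] (drafter_spec_length = 4) and the verifier returns next_tokens
# [5, 8, 7, 2, 6], then next_tokens[:-1] = [5, 8, 7, 2] first disagrees at index 2, so n_matches = 2:
# the two wrongly drafted tokens (3, 9) are revoked and the verifier's token at index 2 (7) is
# appended instead.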
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
"""
Execute the generation process.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
# If the user provides a generation config, replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_reqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_reqs():
output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
for seq in output_seqs_list:
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
if return_token_ids:
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
return output_str, output_tokens_list
else:
return output_str
@property
def has_prompt_template(self) -> bool:
""" """
return self.inference_config.prompt_template is not None
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
"""
This method will format the input prompt according to the prompt template given to the InferenceConfig.
"""
assert (
self.has_prompt_template
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
def add_request(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
**kwargs,
) -> None:
"""
Add requests.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
"""
# apply the prompt template to the input prompts
if self.has_prompt_template and prompts is not None:
prompts = self.format_prompt(prompts)
block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
raise TypeError(
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
)
assert (
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
prompts_num = len(prompts_token_ids)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)
if prompts is None:
prompt = None
else:
prompt = prompts[i]
max_length = kwargs.get("max_length", None)
max_new_tokens = kwargs.get("max_new_tokens", None)
if max_length is None and max_new_tokens is None:
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])
if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
max_output_len=max_new_tokens,
ignore_eos=self.inference_config.ignore_eos,
)
self.request_handler.add_sequence(sequence)
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
n_tokens = sequence_lengths.sum().item()
else:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
n_tokens = n_tokens * batch.current_batch_size
output_tensor = torch.zeros(
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
batch_token_ids = None
if (
self.generation_config.repetition_penalty != 1.0
or self.generation_config.no_repeat_ngram_size > 0
or self.generation_config.forced_eos_token_id is not None
):
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=batch.fd_inter_tensor,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
"""
In each step, do the following:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Get the input ids, input metadata and output placeholder from the BatchBucket.
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
"""
batch = self.request_handler.schedule()
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences
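# Illustrative usage sketch (not part of the file above), assuming distributed initialization
# (e.g. colossalai.launch_from_torch) has been done and a valid InferenceConfig is available:
#
#     from transformers import AutoTokenizer
#     from colossalai.inference.config import InferenceConfig
#
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#     engine = LLMEngine("meta-llama/Llama-2-7b-hf", tokenizer, InferenceConfig(), verbose=True)
#     outputs = engine.generate(prompts=["What is Colossal-AI?"])
#     # outputs is a list of decoded strings; pass return_token_ids=True to also get the token ids.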

View File

@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@ -98,7 +98,46 @@ class RunningList:
self._decoding[seq_id] = self._prefill.pop(seq_id)
class RequestHandler:
class NaiveRequestHandler:
def __init__(self) -> None:
self.running_list: List[DiffusionSequence] = []
self.waiting_list: List[DiffusionSequence] = []
def _has_waiting(self) -> bool:
return len(self.waiting_list) > 0
def _has_running(self) -> bool:
return len(self.running_list) > 0
def check_unfinished_reqs(self):
return self._has_waiting() or self._has_running()
def add_sequence(self, seq: DiffusionSequence):
"""
Add the request to waiting list.
"""
assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
self.waiting_list.append(seq)
def _find_sequence(self, request_id: int) -> DiffusionSequence:
"""
Find the request by request_id.
"""
for seq in self.waiting_list + self.running_list:
if seq.request_id == request_id:
return seq
return None
def schedule(self):
ret = None
if self._has_waiting():
ret = self.waiting_list[0]
self.waiting_list = self.waiting_list[1:]
return ret
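# Illustrative note (not part of the file): the naive handler is a simple FIFO queue for diffusion
# requests; add_sequence() appends a DiffusionSequence to waiting_list and each schedule() call
# pops and returns the oldest waiting request (or None when the queue is empty).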
class RequestHandler(NaiveRequestHandler):
"""
RequestHandler is the core for handling existing requests and updating current batch.
During generation process, we call schedule function each iteration to update current batch.
@ -176,12 +215,12 @@ class RequestHandler:
generated_token_size=inference_config.generated_token_size,
)
def _has_running(self) -> bool:
return not self.running_bb.is_empty()
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
@ -318,7 +357,7 @@ class RequestHandler:
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
def check_unfinished_reqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def total_requests_in_batch_bucket(self) -> int:

View File

@ -0,0 +1,54 @@
import inspect
import types
import torch
from torch import nn
class DiffusionPipe(nn.Module):
"""
This class converts a `DiffusionPipeline` into an `nn.Module` and preserves most of the original attributes, functions and properties.
"""
def __init__(self, source_obj) -> None:
super(DiffusionPipe, self).__init__()
for k, v in source_obj.__dict__.items():
if isinstance(v, nn.Module):
self.add_module(k, v)
else:
setattr(self, k, v)
skip_list = ["_execution_device", "to", "device"]  # these members are overridden or handled separately below
for name, member in inspect.getmembers(source_obj.__class__):
if name in skip_list:
continue
if not name.startswith("__") and not name.endswith("__"):
if isinstance(member, property):
setattr(self.__class__, name, member)
elif inspect.isfunction(member) or inspect.ismethod(member):
bound_method = types.MethodType(member, self)
setattr(self, name, bound_method)
elif not callable(member) and not isinstance(member, property):
setattr(self, name, member)
elif name == "__call__":
bound_method = types.MethodType(member, self)
setattr(self, "_forward", bound_method)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
# return self.device
return torch.device("cuda")
@property
def device(self):
return next(self.parameters()).device
def forward(self, *args, **kwargs):
return self._forward(*args, **kwargs)
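# Illustrative usage sketch (not part of the file above), wrapping a diffusers pipeline so it
# behaves like a regular nn.Module (the model id is only an example):
#
#     from diffusers import PixArtAlphaPipeline
#
#     pipeline = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS")
#     pipe = DiffusionPipe(pipeline)       # sub-models become child modules of the nn.Module
#     image = pipe(prompt="a red panda")   # forward() dispatches to the pipeline's __call__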

View File

@ -0,0 +1,626 @@
# Code refer and adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
# https://github.com/PipeFusion/PipeFusion
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.models import attention_processor
from diffusers.models.attention import Attention
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torch import nn
from torch.distributed import ProcessGroup
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.utils import get_current_device
try:
from flash_attn import flash_attn_func
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
logger = get_dist_logger(__name__)
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
def PixArtAlphaTransformer2DModel_forward(
self: PixArtTransformer2DModel,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
assert hasattr(
self, "patched_parallel_size"
), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size = hidden_states.shape[0]
height, width = (
hidden_states.shape[-2] // self.config.patch_size,
hidden_states.shape[-1] // self.config.patch_size,
)
hidden_states = self.pos_embed(hidden_states)
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if self.caption_projection is not None:
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
2, dim=1
)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
height // self.patched_parallel_size,
width,
self.config.patch_size,
self.config.patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.out_channels,
height // self.patched_parallel_size * self.config.patch_size,
width * self.config.patch_size,
)
)
# enable Distrifusion Optimization
if hasattr(self, "patched_parallel_size"):
from torch import distributed as dist
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
self.output_buffer = torch.empty_like(output)
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
output = output.contiguous()
dist.all_gather(self.buffer_list, output, async_op=False)
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
output = self.output_buffer
return (output,)
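# Illustrative note (not part of the file): with patched_parallel_size = 2 each rank computes only
# its own half of the image rows (height // 2), and the all_gather above concatenates the row
# slices along dim=2 (the height dimension) so every rank ends up with the full-resolution output.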
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
def SD3Transformer2DModel_forward(
self: SD3Transformer2DModel,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor]:
assert hasattr(
self, "patched_parallel_size"
), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size // self.patched_parallel_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
# enable Distrifusion Optimization
if hasattr(self, "patched_parallel_size"):
from torch import distributed as dist
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
self.output_buffer = torch.empty_like(output)
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
output = output.contiguous()
dist.all_gather(self.buffer_list, output, async_op=False)
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
output = self.output_buffer
return (output,)
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
class DistrifusionPatchEmbed(ParallelModule):
def __init__(
self,
module: PatchEmbed,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.module = module
self.rank = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
@staticmethod
def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
distrifusion_embed = DistrifusionPatchEmbed(
module, process_group, model_shard_infer_config=model_shard_infer_config
)
return distrifusion_embed
def forward(self, latent):
module = self.module
if module.pos_embed_max_size is not None:
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size
latent = module.proj(latent)
if module.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if module.layer_norm:
latent = module.norm(latent)
if module.pos_embed is None:
return latent.to(latent.dtype)
# Interpolate or crop positional embeddings as needed
if module.pos_embed_max_size:
pos_embed = module.cropped_pos_embed(height, width)
else:
if module.height != height or module.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=module.pos_embed.shape[-1],
grid_size=(height, width),
base_size=module.base_size,
interpolation_scale=module.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
else:
pos_embed = module.pos_embed
b, c, h = pos_embed.shape
pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
return (latent + pos_embed).to(latent.dtype)
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
class DistrifusionConv2D(ParallelModule):
def __init__(
self,
module: nn.Conv2d,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.module = module
self.rank = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
@staticmethod
def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
return distrifusion_conv
def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, h, w = x.shape
stride = self.module.stride[0]
padding = self.module.padding[0]
output_h = x.shape[2] // stride // self.patched_parallelism_size
idx = dist.get_rank()
h_begin = output_h * idx * stride - padding
h_end = output_h * (idx + 1) * stride + padding
final_padding = [padding, padding, 0, 0]
if h_begin < 0:
h_begin = 0
final_padding[2] = padding
if h_end > h:
h_end = h
final_padding[3] = padding
sliced_input = x[:, :, h_begin:h_end, :]
padded_input = F.pad(sliced_input, final_padding, mode="constant")
return F.conv2d(
padded_input,
self.module.weight,
self.module.bias,
stride=stride,
padding="valid",
)
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
output = self.sliced_forward(input)
return output
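# Illustrative note (not part of the file): for a 512-row input with stride 1, padding 1 and
# patched_parallelism_size 2, rank 0 convolves the row slice [0, 257) and rank 1 the slice
# [255, 512), each producing the 256 output rows of its own patch; the overlapping boundary rows
# are read directly from the full input, so no extra halo exchange is required.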
# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
class DistrifusionFusedAttention(ParallelModule):
def __init__(
self,
module: attention_processor.Attention,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.counter = 0
self.module = module
self.buffer_list = None
self.kv_buffer_idx = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
self.handle = None
self.process_group = process_group
self.warm_step = 5 # for warmup
@staticmethod
def from_native_module(
module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
return DistrifusionFusedAttention(
module=module,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
)
def _forward(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
if self.patched_parallelism_size == 1:
full_kv = kv
else:
if self.buffer_list is None: # buffer not created
full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
elif self.counter <= self.warm_step:
# logger.info(f"warmup: {self.counter}")
dist.all_gather(
self.buffer_list,
kv,
group=self.process_group,
async_op=False,
)
full_kv = torch.cat(self.buffer_list, dim=1)
else:
# logger.info(f"use old kv to infer: {self.counter}")
self.buffer_list[self.kv_buffer_idx].copy_(kv)
full_kv = torch.cat(self.buffer_list, dim=1)
assert self.handle is None, "we should maintain the kv of last step"
self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
) # NOTE(@lry89757) for torch >= 2.2, flash attn has been already integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
if self.handle is not None:
self.handle.wait()
self.handle = None
b, l, c = hidden_states.shape
kv_shape = (b, l, self.module.to_k.out_features * 2)
if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
self.buffer_list = [
torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
for _ in range(self.patched_parallelism_size)
]
self.counter = 0
attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
output = self._forward(
self.module,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
self.counter += 1
return output
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
class DistriSelfAttention(ParallelModule):
def __init__(
self,
module: Attention,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.counter = 0
self.module = module
self.buffer_list = None
self.kv_buffer_idx = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
self.handle = None
self.process_group = process_group
self.warm_step = 3 # for warmup
@staticmethod
def from_native_module(
module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
return DistriSelfAttention(
module=module,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
)
def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
attn = self.module
assert isinstance(attn, Attention)
residual = hidden_states
batch_size, sequence_length, _ = hidden_states.shape
query = attn.to_q(hidden_states)
encoder_hidden_states = hidden_states
k = self.module.to_k(encoder_hidden_states)
v = self.module.to_v(encoder_hidden_states)
kv = torch.cat([k, v], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
if self.patched_parallelism_size == 1:
full_kv = kv
else:
if self.buffer_list is None: # buffer not created
full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
elif self.counter <= self.warm_step:
# logger.info(f"warmup: {self.counter}")
dist.all_gather(
self.buffer_list,
kv,
group=self.process_group,
async_op=False,
)
full_kv = torch.cat(self.buffer_list, dim=1)
else:
# logger.info(f"use old kv to infer: {self.counter}")
self.buffer_list[self.kv_buffer_idx].copy_(kv)
full_kv = torch.cat(self.buffer_list, dim=1)
assert self.handle is None, "we should maintain the kv of last step"
self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
if HAS_FLASH_ATTN:
# flash attn
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
else:
# naive attn
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
# wait for the async all-gather launched in the previous step, then (re)allocate the KV buffers if needed
if self.handle is not None:
self.handle.wait()
self.handle = None
b, l, c = hidden_states.shape
kv_shape = (b, l, self.module.to_k.out_features * 2)
if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
self.buffer_list = [
torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
for _ in range(self.patched_parallelism_size)
]
self.counter = 0
output = self._forward(hidden_states, scale=scale)
self.counter += 1
return output
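For orientation, a minimal, hypothetical sketch (not part of the diff) of how a plain diffusers Attention layer could be wrapped by DistriSelfAttention for patch-parallel inference. The SimpleNamespace is a stand-in for ModelShardInferenceConfig and only carries the single field the class reads; the module sizes and process group handling are illustrative.

# hypothetical wiring sketch; assumes torch.distributed is already initialized (e.g. via torchrun)
import torch
import torch.distributed as dist
from types import SimpleNamespace
from diffusers.models.attention_processor import Attention

group = dist.group.WORLD
world_size = dist.get_world_size()
# stand-in for ModelShardInferenceConfig: only the field DistriSelfAttention reads
infer_config = SimpleNamespace(patched_parallelism_size=world_size)
attn = Attention(query_dim=1152, heads=16, dim_head=72).to("cuda", torch.float16)  # PixArt-like sizes
parallel_attn = DistriSelfAttention.from_native_module(
    attn, process_group=group, model_shard_infer_config=infer_config
)
# each rank holds its own patch of the sequence; the full KV is (asynchronously) all-gathered inside
local_seq_len = 1024 // world_size
hidden = torch.randn(1, local_seq_len, 1152, device="cuda", dtype=torch.float16)
out = parallel_attn(hidden)  # (1, local_seq_len, 1152)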


@@ -0,0 +1,220 @@
# Code adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
from typing import Callable, List, Optional, Union
import PIL.Image
import torch
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_256_BIN,
ASPECT_RATIO_512_BIN,
ASPECT_RATIO_1024_BIN,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from colossalai.logging import get_dist_logger
from ..layers.diffusion import DiffusionPipe
logger = get_dist_logger(__name__)
@torch.no_grad()
def pixart_alpha_forward(
self: DiffusionPipe,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 120,
**kwargs,
) -> PIL.Image:
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
if do_classifier_free_guidance:
resolution = torch.cat([resolution, resolution], dim=0)
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
# self.maybe_free_model_hooks()
return image
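For orientation, a brief hypothetical usage sketch (not part of the diff): assuming pipe is a DiffusionPipe whose forward has been replaced by pixart_alpha_forward (e.g. through the PixArtAlphaInferPolicy registered below), invoking it looks like an ordinary diffusers call; the prompt and resolution are illustrative. Since output_type is forced to "pil" inside the function, the result is PIL images regardless of the argument passed.

# hypothetical call; `pipe` is assumed to be a DiffusionPipe bound to pixart_alpha_forward
images = pipe(
    prompt="an astronaut riding a horse on the moon",
    num_inference_steps=20,
    guidance_scale=4.5,
    height=1024,
    width=1024,
)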


@@ -0,0 +1,178 @@
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from ..layers.diffusion import DiffusionPipe
# TODO(@lry89757): currently returns images only; support more output types later
@torch.no_grad()
def sd3_forward(
self: DiffusionPipe,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
prompt_3,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
return image
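As a worked example of the callback contract used in the denoising loop above, a hypothetical callback_on_step_end is sketched here (not part of the diff): it is invoked as callback_on_step_end(pipe, step, timestep, callback_kwargs) and may return updated tensors for any key listed in callback_on_step_end_tensor_inputs. The pipe object and prompt are assumptions.

# hypothetical step-end callback matching the contract in sd3_forward's loop
def log_latents(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent norm {latents.float().norm().item():.2f}")
    return {"latents": latents}  # returned entries are popped back into the loop

# `pipe` is assumed to be a DiffusionPipe bound to sd3_forward
image = pipe(
    prompt="a photo of a corgi wearing sunglasses",
    num_inference_steps=28,
    guidance_scale=7.0,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
)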


@@ -1,16 +1,22 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
from .pixart_alpha import PixArtAlphaInferPolicy
from .stablediffusion3 import StableDiffusion3InferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
"StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
"PixArtAlphaPipeline": PixArtAlphaInferPolicy,
}
__all__ = [
"NoPaddingLlamaModelInferPolicy",
"NoPaddingBaichuanModelInferPolicy",
"GlideLlamaModelPolicy",
"StableDiffusion3InferPolicy",
"PixArtAlphaInferPolicy",
"model_polic_map",
]
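For orientation, a small hypothetical example of how model_policy_map can be used to resolve an inference policy from a pipeline's class name; the checkpoint id and dispatch pattern are illustrative, and the engine's actual lookup may differ.

# hypothetical policy lookup keyed on the pipeline class name
from diffusers import PixArtAlphaPipeline

pipeline = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS")
policy_cls = model_policy_map.get(pipeline.__class__.__name__)  # -> PixArtAlphaInferPolicy
if policy_cls is None:
    raise ValueError(f"no inference policy registered for {pipeline.__class__.__name__}")
policy = policy_cls()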

Some files were not shown because too many files have changed in this diff.