mirror of https://github.com/hpcaitech/ColossalAI
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction of DimSpec objects. The difference_dict is the same for every DimSpec object, so a single shared copy is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication: cast_to_fp8, cast_from_fp8, all_reduce_fp8 (a small illustrative sketch follows this commit list)
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
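The FP8 items above introduce compressed communication by casting tensors to 8-bit floats with a per-tensor scale before collectives such as all-reduce. Below is a minimal, self-contained sketch of that casting round trip; the function names follow the commit message, the fp8 dtype requires torch >= 2.1, and the actual ColossalAI operators may differ.

```python
import torch

def cast_to_fp8(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Scale so the largest magnitude lands near the top of the FP8 range,
    # then cast; keep the scale so the receiver can undo it after transfer.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = fp8_max / x.abs().max().clamp(min=1e-12)
    return (x * scale).to(fp8_dtype), scale

def cast_from_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
    # Undo the scaling after communication.
    return x_fp8.to(dtype) / scale

# Toy round trip standing in for the payload of an fp8 all-reduce/all-gather.
t = torch.randn(4, 8)
q, s = cast_to_fp8(t)
print((cast_from_fp8(q, s) - t).abs().max())  # small quantization error
```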
pull/5976/head
parent 53cb9606bd
commit 0c10afd372

@@ -1 +1,3 @@
2.1.0-12.1.0
2.2.2-12.1.0
2.3.0-12.1.0
@@ -55,41 +55,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user

- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')

# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git

- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
@@ -49,42 +49,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user

- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')

# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi

- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git

- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
@@ -43,47 +43,28 @@ jobs:
steps:
- name: Install dependencies
run: |
apt update && apt install -y cmake
pip install -U pip setuptools==68.2.2 wheel --user

- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe

- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}

- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')

# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi

- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git

- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
@@ -52,6 +52,7 @@ jobs:
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
./tests/test_data_preparation.sh
./tests/test_train.sh
env:

@@ -61,3 +62,4 @@ jobs:
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data
@@ -10,7 +10,7 @@ import math
import os
from multiprocessing import cpu_count

from colossal_llama.dataset.conversation import LLaMA2_Conv
from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AddedToken, AutoTokenizer

@@ -75,6 +75,8 @@ def main():
# Prepare to the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)

default_conversation = LLaMA3_Conv

# Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
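The hunk above binds `default_conversation` unconditionally, which is what resolves the UnboundLocalError reported in #5931. Below is a minimal restatement of the intended selection logic, assuming the LLaMA-2 template is still chosen when `--llama_version 2` is passed (an assumption based on the surrounding script, not shown in this hunk).

```python
from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv  # import path as in the hunk

def pick_conversation_template(llama_version: int):
    # Bind the name in every branch so later code can never hit an UnboundLocalError.
    default_conversation = LLaMA3_Conv
    if llama_version == 2:
        default_conversation = LLaMA2_Conv  # assumption: LLaMA-2 keeps its own template
    return default_conversation
```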
@@ -128,6 +128,12 @@ def main() -> None:
parser.add_argument("--zero", type=int, default=1)
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
parser.add_argument(
"--skip_save_each_epoch",
action="store_true",
default=False,
help="skip saving the model checkpoint after each epoch is completed.",
)
args = parser.parse_args()

with open(args.config_file, "w") as f:

@@ -370,11 +376,17 @@ def main() -> None:
)
total_loss.fill_(0.0)
pbar.update()

# Save modeling.

if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
step + 1
) == len(dataloader):
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
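The refactored save logic above can be read as a small predicate; here is a hedged restatement (argument names follow the diff, not an exact copy of the training script).

```python
def should_save(step: int, save_interval: int, accumulation_steps: int,
                num_steps: int, skip_save_each_epoch: bool) -> bool:
    # Save every `save_interval` optimizer steps (counted in micro-batches)...
    save = save_interval > 0 and (step + 1) % (save_interval * accumulation_steps) == 0
    # ...and, unless --skip_save_each_epoch is set, also at the end of the epoch.
    if not skip_save_each_epoch:
        save = save or (step + 1) == num_steps
    return save
```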
@@ -146,6 +146,9 @@ docs/.build
examples/wandb/
examples/logs/
examples/output/
examples/training_scripts/logs
examples/training_scripts/wandb
examples/training_scripts/output

examples/awesome-chatgpt-prompts/
temp/
@@ -23,6 +23,10 @@
- [Open QA](#open-qa)
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@@ -135,17 +139,15 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
  [
    {
      "from": "human",
      "from": "user",
      "content": "what are some pranks with a pen i can do?"
    },
    {
      "from": "assistant",
      "content": "Are you looking for practical joke ideas?"
    },
    ...
  ]
},
...
]
```
@@ -171,23 +173,20 @@ Below shows the preference dataset format used in training the reward model.
      "from": "human",
      "content": "Introduce butterflies species in Oregon."
    }
  ]
  ],
  "chosen": [
    {
      "from": "assistant",
      "content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
    },
    ...
  ],
  "rejected": [
    {
      "from": "assistant",
      "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
    },
    ...
  ]
},
...
]
```
@@ -216,7 +215,6 @@ PPO uses two kind of training data--- the prompt data and the sft data (optional
      "from": "human",
      "content": "what are some pranks with a pen i can do?"
    }
    ...
  ]
},
]
@@ -262,9 +260,8 @@ experience buffer size
= train_batch_size * accumulation_steps * num_tp_group
```
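As a concrete reading of the formula above, with hypothetical numbers rather than a recommended configuration:

```python
# experience buffer size = train_batch_size * accumulation_steps * num_tp_group
# e.g. train_batch_size=4, accumulation_steps=2 and 4 tensor-parallel groups:
buffer_size = 4 * 2 * 4  # = 32 experiences collected before each PPO update
print(buffer_size)
```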

## Alternative Option For RLHF: Direct Preference Optimization

For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the [paper](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
## Alternative Option For RLHF: Direct Preference Optimization (DPO)
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in this [paper](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO. Read this [README](./examples/README.md) for more information.

### DPO Training Stage1 - Supervised Instructs Tuning

@@ -277,6 +274,15 @@ For DPO training, you only need the preference dataset. Please follow the instru
#### Step 2: Training
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More details can be found in the [example guideline](./examples/README.md).

## Alternative Option For RLHF: Simple Preference Optimization (SimPO)
Simple Preference Optimization (SimPO), from this [paper](https://arxiv.org/pdf/2405.14734), is similar to DPO but drops the reference model, which makes training more efficient. It adds a reward-shaping term called the target reward margin to enhance training stability, and uses length normalization to better align training with the inference process. Read this [README](./examples/README.md) for more information.

## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO), from this [paper](https://arxiv.org/pdf/2403.07691), is a reference-model-free alignment method that mixes the SFT loss with a reinforcement learning loss computed from an odds-ratio-based implicit reward, which makes training more efficient and stable. Read this [README](./examples/README.md) for more information.

## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
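The alternatives above (DPO, SimPO, ORPO, KTO) differ mainly in how the preference loss is formed. Below is a minimal sketch of the DPO and SimPO objectives on per-sequence log-probabilities, assuming those log-probabilities have already been computed; it is illustrative only and not the repository's batched implementation in `train_dpo.py`.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # DPO: compare the policy/reference log-ratio margin of chosen vs. rejected responses.
    logits = beta * ((policy_chosen_logp - ref_chosen_logp)
                     - (policy_rejected_logp - ref_rejected_logp))
    return -F.logsigmoid(logits).mean()

def simpo_loss(chosen_logp, rejected_logp, chosen_len, rejected_len,
               beta=2.0, gamma=0.5):
    # SimPO: no reference model; length-normalized log-probs plus a target reward margin gamma.
    margin = beta * (chosen_logp / chosen_len - rejected_logp / rejected_len) - gamma
    return -F.logsigmoid(margin).mean()

# Toy check with made-up per-sequence log-probabilities.
print(dpo_loss(torch.tensor([-10.0]), torch.tensor([-12.0]),
               torch.tensor([-11.0]), torch.tensor([-11.5])))
print(simpo_loss(torch.tensor([-10.0]), torch.tensor([-12.0]),
                 torch.tensor([20.0]), torch.tensor([22.0])))
```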

### Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
@@ -441,20 +447,6 @@ If you only have a single 24G GPU. Generally, using lora and "zero2-cpu" will be
If you have multiple GPUs, each with very limited VRAM (say 8GB), you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs that you have.
</details>

## The Plan

- [x] implement PPO fine-tuning
- [x] implement training reward model
- [x] support LoRA
- [x] support inference
- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
- [x] implement PPO-ptx fine-tuning
- [x] support flash-attention
- [x] implement DPO fine-tuning
- [ ] integrate with Ray
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)

### Real-time progress

You will find our progress in the github [project board](https://github.com/orgs/hpcaitech/projects/17/views/1).
@@ -522,7 +514,7 @@ Coati is developed by ColossalAI Team:
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with updated acceleration framework. Add support for DPO, SimPO, ORPO.

The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
@@ -572,6 +564,36 @@ We also appreciate the valuable suggestions provided by [Jian Hu](https://github
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/XueFuzhao/InstructionWild}},
}

@misc{meng2024simposimplepreferenceoptimization,
      title={SimPO: Simple Preference Optimization with a Reference-Free Reward},
      author={Yu Meng and Mengzhou Xia and Danqi Chen},
      year={2024},
      eprint={2405.14734},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2405.14734},
}

@misc{rafailov2023directpreferenceoptimizationlanguage,
      title={Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
      author={Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
      year={2023},
      eprint={2305.18290},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2305.18290},
}

@misc{hong2024orpomonolithicpreferenceoptimization,
      title={ORPO: Monolithic Preference Optimization without Reference Model},
      author={Jiwoo Hong and Noah Lee and James Thorne},
      year={2024},
      eprint={2403.07691},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2403.07691},
}
```

## Licenses
@@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
DATASET_SIZE=320

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
    $BENCHMARK_DATA_DIR/arrow/part-0
)

# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference

colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2_cpu" \
    --max_epochs 1 \
    --accumulation_steps 1 \
    --batch_size 4 \
    --lr 1e-6 \
    --beta 0.1 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --max_length 2048 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --grad_checkpoint \
    --use_flash_attn
@@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
DATASET_SIZE=80

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
    $BENCHMARK_DATA_DIR/arrow/part-0
)

# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto

colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2_cpu" \
    --max_epochs 1 \
    --accumulation_steps 1 \
    --batch_size 2 \
    --lr 1e-5 \
    --beta 0.1 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --max_length 2048 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --grad_checkpoint \
    --use_flash_attn
@@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2

PROJECT_NAME="orpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/orpo" # Path to benchmark data
DATASET_SIZE=160

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
    $BENCHMARK_DATA_DIR/arrow/part-0
)

# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference

colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_orpo.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2" \
    --max_epochs 1 \
    --accumulation_steps 1 \
    --batch_size 4 \
    --lr 8e-6 \
    --lam 0.5 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --max_length 2048 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --grad_checkpoint \
    --use_flash_attn
@@ -0,0 +1,50 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="sft"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/sft" # Path to benchmark data
DATASET_SIZE=640

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
declare -a dataset=(
    $BENCHMARK_DATA_DIR/arrow/part-0
)

# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type sft

# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
colossalai run --nproc_per_node 1 --master_port 31312 ../examples/training_scripts/train_sft.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin zero2 \
    --batch_size 8 \
    --max_epochs 1 \
    --accumulation_steps 1 \
    --lr 5e-5 \
    --lora_rank 32 \
    --max_len 2048 \
    --grad_checkpoint \
    --use_flash_attn
@@ -0,0 +1,55 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="simpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/simpo" # Path to benchmark data
DATASET_SIZE=640

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
    $BENCHMARK_DATA_DIR/arrow/part-0
)

# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference

colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2_cpu" \
    --loss_type "simpo_loss" \
    --max_epochs 1 \
    --accumulation_steps 1 \
    --batch_size 8 \
    --lr 1e-6 \
    --beta 0.1 \
    --gamma 0.6 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --max_length 2048 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --disable_reference_model \
    --length_normalization \
    --grad_checkpoint \
    --use_flash_attn
@@ -0,0 +1,30 @@
from typing import Callable

from torch.utils.data import Dataset


class DummyLLMDataset(Dataset):
    def __init__(self, keys, seq_len, size=500, gen_fn={}):
        self.keys = keys
        self.gen_fn = gen_fn
        self.seq_len = seq_len
        self.data = self._generate_data()
        self.size = size

    def _generate_data(self):
        data = {}
        for key in self.keys:
            if key in self.gen_fn:
                data[key] = self.gen_fn[key]
            else:
                data[key] = [1] * self.seq_len
        return data

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {
            key: self.data[key] if not isinstance(self.data[key], Callable) else self.data[key](idx)
            for key in self.keys
        }
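A short usage sketch for the dummy dataset added above, mirroring how the KTO benchmark constructs it; the sizes and keys here are illustrative.

```python
from dummy_dataset import DummyLLMDataset

# Constant-length dummy samples; "completion" and "label" come from per-index generators.
ds = DummyLLMDataset(
    keys=["prompt", "completion", "label"],
    seq_len=1536,
    size=8,
    gen_fn={"completion": lambda i: [1] * 512, "label": lambda i: i % 2},
)
print(len(ds), ds[0]["label"], len(ds[0]["prompt"]))  # 8 0 1536
```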
@@ -0,0 +1,105 @@
import argparse
import json
import os
import time
from multiprocessing import cpu_count

from datasets import load_dataset
from dummy_dataset import DummyLLMDataset

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        default=None,
        help="The output dir",
    )
    parser.add_argument(
        "--dataset_size",
        type=int,
        required=True,
        default=None,
        help="The size of data",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        required=True,
        default=None,
        help="The max length of data",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        required=True,
        default=None,
        help="The type of data, choose one from ['sft', 'prompt', 'preference', 'kto']",
    )
    args = parser.parse_args()
    if args.data_type == "sft":
        dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_length, args.dataset_size)
    elif args.data_type == "prompt":
        # The PPO prompt dataset is prepared separately.
        pass
    elif args.data_type == "preference":
        dataset = DummyLLMDataset(
            ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
            args.max_length,
            args.dataset_size,
        )
    elif args.data_type == "kto":
        dataset = DummyLLMDataset(
            ["prompt", "completion", "label"],
            args.max_length - 512,
            args.dataset_size,
            gen_fn={
                "completion": lambda x: [1] * 512,
                "label": lambda x: x % 2,
            },
        )
    else:
        raise ValueError(f"Unknown data type {args.data_type}")

    # Save each jsonl spliced dataset.
    output_index = "0"
    output_name = f"part-{output_index}"
    os.makedirs(args.data_dir, exist_ok=True)
    output_jsonl_path = os.path.join(args.data_dir, "json")
    output_arrow_path = os.path.join(args.data_dir, "arrow")
    output_cache_path = os.path.join(args.data_dir, "cache")
    os.makedirs(output_jsonl_path, exist_ok=True)
    os.makedirs(output_arrow_path, exist_ok=True)
    output_jsonl_file_path = os.path.join(output_jsonl_path, output_name + ".jsonl")
    st = time.time()
    with open(file=output_jsonl_file_path, mode="w", encoding="utf-8") as fp_writer:
        count = 0
        for i in range(len(dataset)):
            data_point = dataset[i]
            if count % 500 == 0:
                logger.info(f"processing {count} spliced data points for {fp_writer.name}")
            count += 1
            fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
    logger.info(
        f"Current file {fp_writer.name}; "
        f"Data size: {len(dataset)}; "
        f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
    )
    # Save each arrow spliced dataset
    output_arrow_file_path = os.path.join(output_arrow_path, output_name)
    logger.info(f"Start to save {output_arrow_file_path}")
    dataset = load_dataset(
        path="json",
        data_files=[output_jsonl_file_path],
        cache_dir=os.path.join(output_cache_path, "tokenized"),
        keep_in_memory=False,
        num_proc=cpu_count(),
        split="train",
    )
    dataset.save_to_disk(dataset_path=output_arrow_file_path, num_proc=min(len(dataset), cpu_count()))
@@ -1,24 +1,26 @@
from .conversation import Conversation, setup_conversation_template
from .loader import (
    DataCollatorForKTODataset,
    DataCollatorForPreferenceDataset,
    DataCollatorForPromptDataset,
    DataCollatorForSupervisedDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)
from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft

__all__ = [
    "tokenize_prompt_dataset",
    "tokenize_prompt",
    "DataCollatorForPromptDataset",
    "is_rank_0",
    "DataCollatorForPreferenceDataset",
    "DataCollatorForSupervisedDataset",
    "DataCollatorForKTODataset",
    "StatefulDistributedSampler",
    "load_tokenized_dataset",
    "supervised_tokenize_pretrain",
    "supervised_tokenize_sft",
    "tokenize_sft",
    "tokenize_rlhf",
    "tokenize_kto",
    "setup_conversation_template",
    "Conversation",
]
@@ -18,6 +18,7 @@ class Conversation:
chat_template: str
stop_ids: List[int]
end_of_assistant: str
roles = ["user", "assistant"]

@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):

@@ -85,7 +86,7 @@ class Conversation:
Raises:
AssertionError: If the role is not 'user' or 'assistant'.
"""
assert role in ["user", "assistant"]
assert role in self.roles
self.messages.append({"role": role, "content": message})

def copy(self):
@@ -28,6 +28,8 @@ def load_tokenized_dataset(
    Each instance of dataset is a dictionary with
    `{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
    """
    if not dataset_paths:
        return None
    mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"})
    assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"

@@ -233,6 +235,91 @@ class DataCollatorForPreferenceDataset(object):
        )


@dataclass
class DataCollatorForKTODataset(object):
    """
    Collate instances for kto dataset.
    Each input instance is a tokenized dictionary with fields
    `prompt`(List[int]), `completion`(List[int]) and `label`(bool).
    Each output instance is a tokenized dictionary with fields
    `kl_input_ids`(List[int]), `kl_attention_mask`(List[int]) and `kl_loss_mask`(List[int]).
    `input_ids`(List[int]), `attention_mask`(List[int]), `loss_mask`(List[int]) and `label`(bool).
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    ignore_index: int = -100

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """

        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary that contains the following fields:
                `prompt`(List[int]), `completion`(List[int]) and `label`(bool, if the sample is desirable or not).

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`: `torch.Tensor` of shape (bsz, max_len);
                `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
                `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{self.tokenizer.pad_token_id}`"
        )
        # prepare the preference data
        prompt = [torch.LongTensor(instance["prompt"]) for instance in instances]
        prompt_zeros = [torch.zeros_like(t) for t in prompt]
        completion = [torch.LongTensor(instance["completion"]) for instance in instances]
        completion_ones = [torch.ones_like(t) for t in completion]
        label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances]
        input_ids = [torch.cat([prompt[i], completion[i]], dim=-1) for i in range(len(instances))]
        loss_mask = [torch.cat([prompt_zeros[i], completion_ones[i]], dim=-1) for i in range(len(instances))]
        # right padding
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )  # (bsz, max_len)
        loss_mask = torch.nn.utils.rnn.pad_sequence(
            sequences=loss_mask, batch_first=True, padding_value=0
        )  # (bsz, max_len)
        to_pad = self.max_length - input_ids.size(1)
        input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
        loss_mask = F.pad(loss_mask, (0, to_pad), value=0)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)  # `torch.BoolTensor`, (bsz, max_len)

        # prepare the KL data (each prompt paired with another sample's completion)
        kl_completion = completion[::-1]  # y'
        kl_completion_ones = [torch.ones_like(t) for t in kl_completion]
        kl_input_ids = [torch.cat([prompt[i], kl_completion[i]], dim=-1) for i in range(len(instances))]
        kl_loss_mask = [torch.cat([prompt_zeros[i], kl_completion_ones[i]], dim=-1) for i in range(len(instances))]
        # right padding
        kl_input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=kl_input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )  # (bsz, max_len)
        kl_loss_mask = torch.nn.utils.rnn.pad_sequence(
            sequences=kl_loss_mask, batch_first=True, padding_value=0
        )  # (bsz, max_len)
        to_pad = self.max_length - kl_input_ids.size(1)
        kl_input_ids = F.pad(kl_input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
        kl_loss_mask = F.pad(kl_loss_mask, (0, to_pad), value=0)
        kl_attention_mask = kl_input_ids.ne(self.tokenizer.pad_token_id)  # `torch.BoolTensor`, (bsz, max_len)
        data_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "label": torch.stack(label),
            "kl_input_ids": kl_input_ids,
            "kl_attention_mask": kl_attention_mask,
            "kl_loss_mask": kl_loss_mask,
        }
        return data_dict


class StatefulDistributedSampler(DistributedSampler):
    def __init__(
        self,
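A minimal sketch of driving the collator added above with toy token ids; the import path and tokenizer choice are assumptions, and the only real requirement is a tokenizer with a valid `pad_token_id`.

```python
from transformers import AutoTokenizer
from coati.dataset.loader import DataCollatorForKTODataset  # import path assumed

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # ensure pad_token_id is a valid integer

collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=16)
batch = collator([
    {"prompt": [11, 12], "completion": [13, 14, 15], "label": True},
    {"prompt": [21], "completion": [22, 23], "label": False},
])
print(batch["input_ids"].shape, batch["kl_input_ids"].shape)  # both (2, 16)
print(batch["label"])  # tensor([ True, False])
```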
@@ -23,11 +23,10 @@ IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]


def supervised_tokenize_sft(
def tokenize_sft(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@@ -39,51 +38,37 @@ def supervised_tokenize_sft(

Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""

if ignore_index is None:
ignore_index = IGNORE_INDEX
ignore_index = IGNORE_INDEX

messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []

for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")

template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and start with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])

if len(template.messages) % 2 != 0:
# Force to end with assistant response
template.messages = template.messages[0:-1]

# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
turns = [i for i in range(1, len(messages) // 2 + 1)]

lo, hi = 0, len(turns)
while lo < hi:
mid = (lo + hi) // 2
if max_length - 1 < len(
tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
):
hi = mid
else:
lo = mid + 1
target_turn_index = lo

# The tokenized length for first turn already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
# tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines
prompt = template.get_prompt()
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages, prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
if tokenized is None:
return dict(
input_ids=None,
labels=None,

@@ -93,44 +78,18 @@ def supervised_tokenize_sft(
seq_category=None,
)

target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)

labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [ignore_index]
labels[start:end] = tokenized[start:end]

# truncate the sequence at the last token that requires loss calculation
to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
if labels[i] == ignore_index:
to_truncate_len += 1
else:
break
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]

if tokenizer.bos_token_id is not None:
# Force to add bos token at the beginning of the tokenized sequence if the input ids don't start with bos
if tokenized[0] != tokenizer.bos_token_id:
# Some chat templates already include bos token
tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] + labels
labels = [-100] + labels

if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [tokenizer.eos_token_id]
else:
labels[-1] = tokenizer.eos_token_id

# Some models without bos/eos may raise the following errors
# log decoded inputs and labels for debugging
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
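The label-masking loop in the hunk above keeps loss only on assistant spans. Below is a tiny, self-contained illustration of that masking pattern, using hypothetical token ids and span indices rather than real tokenizer output.

```python
IGNORE_INDEX = -100

tokenized = [5, 6, 7, 8, 9, 10]  # full conversation tokens
starts, ends = [2], [5]          # assistant response occupies positions 2..4

labels = [IGNORE_INDEX] * len(tokenized)
for start, end in zip(starts, ends):
    labels[start:end] = tokenized[start:end]

print(labels)  # [-100, -100, 7, 8, 9, -100] -> loss only on the assistant span
```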
@@ -167,11 +126,10 @@ def supervised_tokenize_sft(
)


def tokenize_prompt_dataset(
def tokenize_prompt(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@@ -179,48 +137,39 @@ def tokenize_prompt_dataset(
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX

messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []

for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")

template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and start with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])

# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
target_turn = len(template.messages)
if target_turn % 2 != 1:
if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
target_turn = target_turn - 1
template.messages = template.messages[:-1]

# Prepare data
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[:target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]

if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized

# Skip overlength data
if max_length - 1 < len(tokenized):
if len(tokenized) > max_length:
return dict(
input_ids=None,
inputs_decode=None,
@ -231,47 +180,32 @@ def tokenize_prompt_dataset(
|
|||
# `inputs_decode` can be used to check whether the tokenization method is correct.
|
||||
return dict(
|
||||
input_ids=tokenized,
|
||||
inputs_decode=tokenizer.decode(tokenized),
|
||||
inputs_decode=prompt,
|
||||
seq_length=len(tokenized),
|
||||
seq_category=data_point["category"] if "category" in data_point else "None",
|
||||
)
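For reference, a call to tokenize_prompt might look like the following sketch; the tokenizer and conv_template objects are assumed to exist and depend on the chat template in use:

data_point = {"messages": [{"from": "user", "content": "What is tensor parallelism?"}]}
result = tokenize_prompt(data_point, tokenizer, conversation_template=conv_template, max_length=4096)
# result["input_ids"] ends with the assistant generation prompt,
# result["inputs_decode"] holds the raw prompt string for sanity checks,
# and overlength samples come back with input_ids=None.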
|
||||
|
||||
|
||||
def apply_rlhf_data_format(
|
||||
template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
|
||||
):
|
||||
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
|
||||
target_turn = int(len(template.messages) / 2)
|
||||
prompt = template.get_prompt(target_turn * 2)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
loss_mask = [0] * len(tokenized)
|
||||
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
|
||||
if mask_token is None:
|
||||
mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
|
||||
# no truncation applied
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
|
||||
|
||||
loss_mask = [0] * len(tokenized)
|
||||
label_decode = []
|
||||
for start, end in zip(starts[-1:], ends[-1:]):
|
||||
# only the last round (chosen/rejected) counts
|
||||
if end == len(tokenized):
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
loss_mask = loss_mask + [1]
|
||||
loss_mask[start:end] = [1] * len(loss_mask[start:end])
|
||||
label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
|
||||
# only the last round (chosen/rejected) is used to calculate loss
|
||||
for i in range(starts[-1], ends[-1]):
|
||||
loss_mask[i] = 1
|
||||
label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
|
||||
if tokenizer.bos_token_id is not None:
|
||||
if tokenized[0] != tokenizer.bos_token_id:
|
||||
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||
loss_mask = [0] + loss_mask
|
||||
|
||||
if tokenizer.eos_token_id is not None:
|
||||
# Force to add eos token at the end of the tokenized sequence
|
||||
if tokenized[-1] != tokenizer.eos_token_id:
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
loss_mask = loss_mask + [1]
|
||||
else:
|
||||
loss_mask[-1] = 1
|
||||
|
||||
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
|
||||
|
||||
|
||||
|
@ -279,39 +213,29 @@ def tokenize_rlhf(
|
|||
data_point: Dict[str, str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
conversation_template: Conversation = None,
|
||||
ignore_index: int = None,
|
||||
max_length: int = 4096,
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
A tokenization function to tokenize an original preference data point as follows:
|
||||
{"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
|
||||
{"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
|
||||
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
|
||||
"""
|
||||
if ignore_index is None:
|
||||
ignore_index = IGNORE_INDEX
|
||||
|
||||
context = data_point["context"]
|
||||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
for mess in context:
|
||||
from_str = mess["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
|
||||
if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
|
||||
# Concatenate adjacent messages from the same role
|
||||
template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
|
||||
else:
|
||||
template.append_message(from_str, mess["content"])
|
||||
for idx, mess in enumerate(context):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{context}"
|
||||
)
|
||||
template.append_message(mess["from"], mess["content"])
|
||||
|
||||
if len(template.messages) % 2 != 1:
|
||||
warnings.warn(
|
||||
"Please make sure leading context starts and ends with a line from human\nLeading context: "
|
||||
"Please make sure leading context starts and ends with a line from user\nLeading context: "
|
||||
+ str(template.messages)
|
||||
)
|
||||
return dict(
|
||||
|
@ -322,31 +246,27 @@ def tokenize_rlhf(
|
|||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
round_of_context = int((len(template.messages) - 1) / 2)
|
||||
|
||||
assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
|
||||
assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
|
||||
chosen = deepcopy(template)
|
||||
rejected = deepcopy(template)
|
||||
chosen_continuation = data_point["chosen"]
|
||||
rejected_continuation = data_point["rejected"]
|
||||
for round in range(len(chosen_continuation)):
|
||||
if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{chosen_continuation}"
|
||||
)
|
||||
chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
|
||||
|
||||
for round in range(len(data_point["chosen"])):
|
||||
from_str = data_point["chosen"][round]["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
chosen.append_message(from_str, data_point["chosen"][round]["content"])
|
||||
|
||||
for round in range(len(data_point["rejected"])):
|
||||
from_str = data_point["rejected"][round]["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
rejected.append_message(from_str, data_point["rejected"][round]["content"])
|
||||
for round in range(len(rejected_continuation)):
|
||||
if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{rejected_continuation}"
|
||||
)
|
||||
rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
|
||||
|
||||
(
|
||||
chosen_input_ids,
|
||||
|
@ -356,48 +276,22 @@ def tokenize_rlhf(
|
|||
rejected_loss_mask,
|
||||
rejected_label_decode,
|
||||
) = (None, None, None, None, None, None)
|
||||
if (
|
||||
len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
|
||||
<= max_length - 1
|
||||
and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
|
||||
<= max_length - 1
|
||||
):
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
|
||||
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
|
||||
chosen_data_packed["input_ids"],
|
||||
chosen_data_packed["loss_mask"],
|
||||
chosen_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
rejected_data_packed = apply_rlhf_data_format(
|
||||
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
|
||||
)
|
||||
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
|
||||
rejected_data_packed["input_ids"],
|
||||
rejected_data_packed["loss_mask"],
|
||||
rejected_data_packed["label_decode"],
|
||||
)
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
|
||||
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
|
||||
chosen_data_packed["input_ids"],
|
||||
chosen_data_packed["loss_mask"],
|
||||
chosen_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
|
||||
if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask):
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
chosen_label_decode=None,
|
||||
rejected_input_ids=None,
|
||||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
|
||||
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
|
||||
rejected_data_packed["input_ids"],
|
||||
rejected_data_packed["loss_mask"],
|
||||
rejected_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
return {
|
||||
"chosen_input_ids": chosen_input_ids,
|
||||
"chosen_loss_mask": chosen_loss_mask,
|
||||
"chosen_label_decode": chosen_label_decode,
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_loss_mask": rejected_loss_mask,
|
||||
"rejected_label_decode": rejected_label_decode,
|
||||
}
|
||||
else:
|
||||
if len(chosen_input_ids) > max_length or len(rejected_input_ids) > max_length:
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
|
@ -406,3 +300,81 @@ def tokenize_rlhf(
|
|||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
|
||||
if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
chosen_label_decode=None,
|
||||
rejected_input_ids=None,
|
||||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
|
||||
return {
|
||||
"chosen_input_ids": chosen_input_ids,
|
||||
"chosen_loss_mask": chosen_loss_mask,
|
||||
"chosen_label_decode": chosen_label_decode,
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_loss_mask": rejected_loss_mask,
|
||||
"rejected_label_decode": rejected_label_decode,
|
||||
}
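A preference sample for tokenize_rlhf could look like the following sketch (tokenizer and conv_template are assumed; the message contents are placeholders):

data_point = {
    "context": [{"from": "user", "content": "Summarize the report."}],
    "chosen": [{"from": "assistant", "content": "The report argues that ..."}],
    "rejected": [{"from": "assistant", "content": "I cannot help with that."}],
}
out = tokenize_rlhf(data_point, tokenizer, conversation_template=conv_template, max_length=4096)
# out["chosen_input_ids"] and out["rejected_input_ids"] share the tokenized context prefix
# and differ only in the final assistant turn, which is the region where loss_mask == 1.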
|
||||
|
||||
|
||||
def tokenize_kto(
|
||||
data_point: Dict[str, str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
conversation_template: Conversation = None,
|
||||
max_length: int = 4096,
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
Tokenize a dataset for KTO training.
The raw input data is a conversation with the following format:
{
"prompt": [{"from": "user", "content": "xxx"}...],
"completion": {"from": "assistant", "content": "xxx"},
"label": true/false
}
It returns three fields:
the context, which contains the query and the assistant start,
the completion, which only contains the assistant's answer,
and a binary label, which indicates whether the sample is preferred or not.
"""
|
||||
prompt = data_point["prompt"]
|
||||
completion = data_point["completion"]
|
||||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
if prompt[0].get("from", None) != "user":
|
||||
raise ValueError("conversation should start with user")
|
||||
if completion.get("from", None) != "assistant":
|
||||
raise ValueError("conversation should end with assistant")
|
||||
|
||||
for mess in prompt:
|
||||
if mess.get("from", None) == "user":
|
||||
template.append_message("user", mess["content"])
|
||||
elif mess.get("from", None) == "assistant":
|
||||
template.append_message("assistant", mess["content"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {mess.get('from', None)}")
|
||||
generation_prompt = template.get_prompt(len(prompt), add_generation_prompt=True)
|
||||
template.append_message("assistant", completion["content"])
|
||||
full_prompt = template.get_prompt(len(prompt) + 1, add_generation_prompt=False)
|
||||
tokenized_full_prompt = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]
|
||||
if len(tokenized_full_prompt) + 1 > max_length:
|
||||
return dict(prompt=None, completion=None, label=None, input_id_decode=None, completion_decode=None)
|
||||
tokenized_generation_prompt = tokenizer(generation_prompt, add_special_tokens=False)["input_ids"]
|
||||
tokenized_completion = tokenized_full_prompt[len(tokenized_generation_prompt) :]
|
||||
tokenized_completion = deepcopy(tokenized_completion)
|
||||
if tokenizer.bos_token_id is not None and tokenized_generation_prompt[0] != tokenizer.bos_token_id:
|
||||
tokenized_generation_prompt = [tokenizer.bos_token_id] + tokenized_generation_prompt
|
||||
decoded_full_prompt = tokenizer.decode(tokenized_full_prompt, skip_special_tokens=False)
|
||||
decoded_completion = tokenizer.decode(tokenized_completion, skip_special_tokens=False)
|
||||
|
||||
return {
|
||||
"prompt": tokenized_generation_prompt,
|
||||
"completion": tokenized_completion,
|
||||
"label": data_point["label"],
|
||||
"input_id_decode": decoded_full_prompt,
|
||||
"completion_decode": decoded_completion,
|
||||
}
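A KTO sample and call might look like the following sketch (tokenizer and conv_template assumed; contents are placeholders):

data_point = {
    "prompt": [{"from": "user", "content": "Explain pipeline parallelism."}],
    "completion": {"from": "assistant", "content": "Pipeline parallelism splits the model ..."},
    "label": True,
}
out = tokenize_kto(data_point, tokenizer, conversation_template=conv_template, max_length=4096)
# out["prompt"] + out["completion"] recovers the tokenized conversation
# (with bos prepended to the prompt when the tokenizer defines one).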
@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
|
|||
return -1
|
||||
|
||||
|
||||
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
|
||||
def tokenize_and_concatenate(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
text: List[str],
|
||||
require_loss: List[bool],
|
||||
max_length: int,
|
||||
discard_non_loss_tokens_at_tail: bool = True,
|
||||
):
|
||||
"""
|
||||
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
|
||||
|
||||
|
@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
|
|||
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
|
||||
text (List[str]): The list of texts to tokenize.
|
||||
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
|
||||
max_length: used to truncate the input ids
|
||||
discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
|
||||
|
||||
if the first round has already exceeded max length
- if the user query alone already exceeds max length, discard the sample
- if only the first assistant response exceeds max length, truncate the response to fit the max length
else keep the first several complete rounds of the conversation until max length is reached
|
||||
|
||||
Returns:
|
||||
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
|
||||
|
@ -106,10 +119,18 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
|
|||
loss_ends = []
|
||||
for s, r in zip(text, require_loss):
|
||||
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
|
||||
if r:
|
||||
loss_starts.append(len(input_ids))
|
||||
loss_ends.append(len(input_ids) + len(tokenized))
|
||||
input_ids.extend(tokenized)
|
||||
if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
|
||||
if r:
|
||||
loss_starts.append(len(input_ids))
|
||||
loss_ends.append(len(input_ids) + len(tokenized))
|
||||
input_ids.extend(tokenized)
|
||||
if max_length and loss_starts[0] >= max_length:
|
||||
return None, None, None
|
||||
if discard_non_loss_tokens_at_tail:
|
||||
input_ids = input_ids[: loss_ends[-1]]
|
||||
if max_length:
|
||||
input_ids = input_ids[:max_length]
|
||||
loss_ends[-1] = min(max_length, loss_ends[-1])
|
||||
return input_ids, loss_starts, loss_ends
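To illustrate the contract, only chunks flagged in require_loss produce loss_starts/loss_ends entries; the chunk strings below are placeholders (a hedged example, not part of the patch):

chunks = ["<|user|>Hi<|assistant|>", "Hello!", "<|end|>"]
require_loss = [False, True, False]
input_ids, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=4096)
# input_ids[starts[0]:ends[0]] decodes back to "Hello!"; with max_length=None no truncation is applied.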
|
||||
|
||||
|
||||
|
@ -125,6 +146,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
|
|||
content_length = (
|
||||
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
|
||||
)
|
||||
# if the tokenized content start with a leading space, we want to keep it in loss calculation
|
||||
# e.g., Assistant: I am saying...
|
||||
# if the tokenized content doesn't start with a leading space, we only need to keep the content in loss calculation
|
||||
# e.g.,
|
||||
# Assistant: # '\n' as line breaker
|
||||
# I am saying...
|
||||
if prompt[first_occur - 1] != " ":
|
||||
chunks.append(prompt[start_idx:first_occur])
|
||||
chunks.append(prompt[first_occur : first_occur + content_length])
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from .base import BaseModel
|
||||
from .critic import Critic
|
||||
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
|
||||
from .lora import convert_to_lora_module
|
||||
from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||
from .lora import LoraConfig, convert_to_lora_module, lora_manager
|
||||
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||
from .reward_model import RewardModel
|
||||
from .utils import disable_dropout
|
||||
|
||||
|
@ -14,9 +14,11 @@ __all__ = [
|
|||
"ValueLoss",
|
||||
"LogSigLoss",
|
||||
"LogExpLoss",
|
||||
"LoraConfig",
|
||||
"lora_manager",
|
||||
"convert_to_lora_module",
|
||||
"DpoLoss",
|
||||
"generate",
|
||||
"KTOLoss" "generate",
|
||||
"generate_streaming",
|
||||
"disable_dropout",
|
||||
"update_model_kwargs_fn",
|
||||
|
|
|
@ -42,7 +42,6 @@ class BaseModel(nn.Module):
|
|||
out = self.model(dummy_input)
|
||||
self.last_hidden_state_size = out.last_hidden_state.shape[-1]
|
||||
self.model = self.model.cpu()
|
||||
# print("self.last_hidden_state_size: ",self.last_hidden_state_size)
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
@ -5,10 +5,11 @@ LORA utils
|
|||
import dataclasses
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@ -18,148 +19,349 @@ logger = get_dist_logger()
|
|||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LoRAManager:
|
||||
merge_weights: bool = False
|
||||
class LoraManager:
|
||||
able_to_merge: bool = True
|
||||
|
||||
|
||||
LORA_MANAGER = LoRAManager()
|
||||
lora_manager = LoraManager()
|
||||
|
||||
|
||||
class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
||||
@dataclasses.dataclass
|
||||
class LoraConfig:
|
||||
r: int = 0
|
||||
lora_alpha: int = 32
|
||||
linear_lora_dropout: float = 0.1
|
||||
embedding_lora_dropout: float = 0.0
|
||||
lora_train_bias: str = "none"
|
||||
lora_initialization_method: str = "kaiming_uniform"
|
||||
target_modules: List = None
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_file: str):
|
||||
import json
|
||||
|
||||
with open(config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
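A LoraConfig can be constructed directly or loaded from a JSON file via from_file; the values below are illustrative:

lora_config = LoraConfig(r=8, lora_alpha=32, lora_train_bias="none", target_modules=["q_proj", "v_proj"])
# Equivalent JSON for LoraConfig.from_file("lora_config.json"):
# {"r": 8, "lora_alpha": 32, "lora_train_bias": "none", "target_modules": ["q_proj", "v_proj"]}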
|
||||
|
||||
|
||||
class LoraBase(lora.LoRALayer, nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight: nn.Parameter,
|
||||
bias: Optional[nn.Parameter],
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
fan_in_fan_out: bool = False,
|
||||
lora_alpha: int = 32,
|
||||
lora_dropout: float = 0.1,
|
||||
lora_initialization_method: str = "kaiming_uniform",
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
||||
self.weight = weight
|
||||
self.bias = bias
|
||||
|
||||
out_features, in_features = weight.shape
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
|
||||
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
self.r = r
|
||||
self.lora_alpha = lora_alpha
|
||||
self.lora_dropout = nn.Dropout(lora_dropout)
|
||||
self.merged = False
|
||||
self.lora_initialization_method = lora_initialization_method
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
self.lora_A = None
|
||||
self.lora_B = None
|
||||
|
||||
def reset_parameters(self):
|
||||
if hasattr(self, "lora_A"):
|
||||
# Initialize A with the default values for nn.Linear and set B to zero.
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != (
|
||||
self.out_features,
|
||||
self.in_features,
|
||||
):
|
||||
# Initialize A with the default values for nn.Linear and set B to zero.
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
elif self.lora_initialization_method == "PiSSA":
|
||||
# PiSSA method in this paper: https://arxiv.org/abs/2404.02948
|
||||
# Assume the SVD of the original weights is W = USV^T
|
||||
# Initialize a frozen weight to U[:,r:]S[r:,r:]V^T[:,r:] to store the less significant part of W
|
||||
# Only A, B are trainable; they are initialized to S[:r,:r]^0.5 V^T[:r,:] and U[:,:r] S[:r,:r]^0.5 respectively
|
||||
# self.scaling = 1.
|
||||
# SVD
|
||||
U, S, Vh = torch.svd_lowrank(
|
||||
self.weight.to(torch.float32).data, self.r, niter=4
|
||||
) # U: [out_features, in_features], S: [in_features], V: [in_features, in_features]
|
||||
# weight_backup = self.weight.clone()
|
||||
|
||||
# Initialize A, B
|
||||
S = S / self.scaling
|
||||
self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous()
|
||||
self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous()
|
||||
# Initialize weight
|
||||
# To reduce floating point error, we use residual instead of directly using U[:, :self.r] @ S[:self.r] @ Vh[:self.r, :]
|
||||
self.weight.data = (
|
||||
((self.weight - self.scaling * self.lora_B @ self.lora_A)).contiguous().to(self.weight.dtype)
|
||||
)
|
||||
self.lora_A.requires_grad = True
|
||||
self.lora_B.requires_grad = True
|
||||
else:
|
||||
raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}")
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""
|
||||
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
|
||||
"""
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
self.training = mode
|
||||
if LORA_MANAGER.merge_weights:
|
||||
if mode and self.merged:
|
||||
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
|
||||
raise NotImplementedError("LoRA unmerge is not tested.")
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
|
||||
# FIXME(csric): temporary fix
|
||||
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
|
||||
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
|
||||
self.reset_parameters()
|
||||
else:
|
||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = False
|
||||
elif not mode and not self.merged:
|
||||
warnings.warn("Invoke module.eval() would merge LoRA weights.")
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
||||
delattr(self, "lora_A")
|
||||
delattr(self, "lora_B")
|
||||
self.merged = True
|
||||
if mode and self.merged:
|
||||
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
|
||||
raise NotImplementedError("LoRA unmerge is not tested.")
|
||||
elif not mode and not self.merged and lora_manager.able_to_merge:
|
||||
warnings.warn("Invoke module.eval() would merge LoRA weights.")
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += self.lora_B @ self.lora_A * self.scaling
|
||||
delattr(self, "lora_A")
|
||||
delattr(self, "lora_B")
|
||||
self.merged = True
|
||||
|
||||
return self
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
class LoraLinear(LoraBase):
|
||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: nn.Parameter,
|
||||
bias: Union[nn.Parameter, bool],
|
||||
r: int = 0,
|
||||
lora_alpha: int = 32,
|
||||
lora_dropout: float = 0.0,
|
||||
lora_initialization_method: str = "kaiming_uniform",
|
||||
):
|
||||
super().__init__(
|
||||
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
|
||||
)
|
||||
self.weight = weight
|
||||
self.bias = bias
|
||||
if bias is True:
|
||||
self.bias = nn.Parameter(torch.zeros(weight.shape[0]))
|
||||
if bias is not None:
|
||||
self.bias.requires_grad = True
|
||||
|
||||
out_features, in_features = weight.shape
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
|
||||
self.lora_initialization_method = lora_initialization_method
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
|
||||
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.r > 0 and not self.merged:
|
||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
||||
if self.r > 0:
|
||||
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
|
||||
result = F.linear(x, self.weight, bias=self.bias)
|
||||
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
|
||||
return result
|
||||
else:
|
||||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
return F.linear(x, self.weight, bias=self.bias)
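Usage of the rewritten LoraLinear is unchanged from a plain linear layer; a minimal sketch:

import torch
import torch.nn as nn

base = nn.Linear(4096, 4096)
lora_layer = LoraLinear(base.weight, base.bias, r=8, lora_alpha=32)
y = lora_layer(torch.randn(2, 4096))  # frozen base projection plus the scaled low-rank update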
|
||||
|
||||
|
||||
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
class LoraEmbedding(LoraBase):
|
||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: nn.Parameter,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 32,
|
||||
lora_dropout: float = 0.1,
|
||||
num_embeddings: int = None,
|
||||
embedding_dim: int = None,
|
||||
padding_idx: Optional[int] = None,
|
||||
max_norm: Optional[float] = None,
|
||||
norm_type: float = 2.0,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
lora_initialization_method: str = "kaiming_uniform",
|
||||
):
|
||||
super().__init__(
|
||||
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
|
||||
)
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = max_norm
|
||||
self.norm_type = norm_type
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
self.sparse = sparse
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.weight = weight
|
||||
|
||||
in_features, out_features = num_embeddings, embedding_dim
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
|
||||
self.lora_initialization_method = lora_initialization_method
|
||||
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
|
||||
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
|
||||
# reset parameters
|
||||
nn.init.zeros_(self.lora_A)
|
||||
nn.init.normal_(self.lora_B)
|
||||
|
||||
def _embed(self, x: torch.Tensor, weight) -> torch.Tensor:
|
||||
return F.embedding(
|
||||
x,
|
||||
weight,
|
||||
padding_idx=self.padding_idx,
|
||||
max_norm=self.max_norm,
|
||||
norm_type=self.norm_type,
|
||||
scale_grad_by_freq=self.scale_grad_by_freq,
|
||||
sparse=self.sparse,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
base_embedding = self._embed(x, self.weight)
|
||||
# base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing
|
||||
if self.r > 0 and not self.merged:
|
||||
lora_A_embedding = self._embed(x, self.lora_A.t())
|
||||
embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling
|
||||
return embedding
|
||||
else:
|
||||
return base_embedding
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""
|
||||
This function runs when model.train() is invoked. It is used to prepare the embedding layer for training
|
||||
"""
|
||||
|
||||
self.training = mode
|
||||
if mode and self.merged:
|
||||
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
|
||||
raise NotImplementedError("LoRA unmerge is not tested.")
|
||||
elif not mode and not self.merged and lora_manager.able_to_merge:
|
||||
warnings.warn("Invoke module.eval() would merge LoRA weights.")
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling
|
||||
delattr(self, "lora_A")
|
||||
delattr(self, "lora_B")
|
||||
self.merged = True
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear:
|
||||
"""
|
||||
Wraps a linear layer with LoRA functionality.
|
||||
|
||||
Args:
|
||||
linear (nn.Linear): The linear layer to be wrapped.
|
||||
lora_rank (int): The rank of the LoRA decomposition.
|
||||
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
|
||||
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
|
||||
|
||||
Returns:
|
||||
LoraLinear: The wrapped linear layer with LoRA functionality.
|
||||
"""
|
||||
assert (
|
||||
lora_rank <= linear.in_features
|
||||
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
|
||||
lora_config.r <= linear.in_features
|
||||
), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})"
|
||||
bias = None
|
||||
if lora_config.lora_train_bias in ["all", "lora"]:
|
||||
bias = linear.bias
|
||||
if bias is None:
|
||||
bias = True
|
||||
lora_linear = LoraLinear(
|
||||
linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method
|
||||
)
|
||||
return lora_linear
|
||||
|
||||
|
||||
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None:
|
||||
"""
|
||||
Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to convert to LoRA form.
|
||||
lora_rank (int): The rank of the LoRA approximation.
|
||||
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
|
||||
parent_name (str): The name of the parent module.
|
||||
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, nn.Linear):
|
||||
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
|
||||
if lora_config.target_modules is None or any(
|
||||
[name in target_module for target_module in lora_config.target_modules]
|
||||
):
|
||||
if dist.is_initialized() and dist.get_rank() == 0:
|
||||
logger.info(f"Converting {parent_name}.{name} to LoRA")
|
||||
setattr(module, name, _lora_linear_wrapper(child, lora_config))
|
||||
elif isinstance(child, nn.Embedding):
|
||||
if lora_config.target_modules is None or any(
|
||||
[name in target_module for target_module in lora_config.target_modules]
|
||||
):
|
||||
if dist.is_initialized() and dist.get_rank() == 0:
|
||||
logger.info(f"Converting {parent_name}.{name} to LoRA")
|
||||
setattr(
|
||||
module,
|
||||
name,
|
||||
LoraEmbedding(
|
||||
child.weight,
|
||||
r=lora_config.r,
|
||||
lora_alpha=lora_config.lora_alpha,
|
||||
lora_dropout=lora_config.embedding_lora_dropout,
|
||||
num_embeddings=child.num_embeddings,
|
||||
embedding_dim=child.embedding_dim,
|
||||
padding_idx=child.padding_idx,
|
||||
max_norm=child.max_norm,
|
||||
norm_type=child.norm_type,
|
||||
scale_grad_by_freq=child.scale_grad_by_freq,
|
||||
sparse=child.sparse,
|
||||
lora_initialization_method=lora_config.lora_initialization_method,
|
||||
),
|
||||
)
|
||||
else:
|
||||
_convert_to_lora_recursively(child, lora_rank)
|
||||
_convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config)
|
||||
|
||||
|
||||
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
|
||||
def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module:
|
||||
"""Convert a torch.nn.Module to a LoRA module.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to convert.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
|
||||
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
|
||||
|
||||
Returns:
|
||||
nn.Module: The converted module.
|
||||
"""
|
||||
if lora_rank <= 0:
|
||||
if lora_config.r <= 0:
|
||||
return module
|
||||
_convert_to_lora_recursively(module, lora_rank)
|
||||
lora.mark_only_lora_as_trainable(module, lora_train_bias)
|
||||
# make all parameter not trainable, if lora_train_bias is "all", set bias to trainable
|
||||
total_parameter_size = 0
|
||||
for name, p in module.named_parameters():
|
||||
p.requires_grad = False
|
||||
if "bias" in name and lora_config.lora_train_bias == "all":
|
||||
p.requires_grad = True
|
||||
total_parameter_size += p.numel()
|
||||
_convert_to_lora_recursively(module, "", lora_config)
|
||||
trainable_parameter_size = 0
|
||||
for name, p in module.named_parameters():
|
||||
if p.requires_grad == True:
|
||||
trainable_parameter_size += p.numel()
|
||||
if dist.is_initialized() and dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Trainable parameter size: {trainable_parameter_size/1024/1024:.2f}M\nOriginal trainable parameter size: {total_parameter_size/1024/1024:.2f}M\nPercentage: {trainable_parameter_size/total_parameter_size*100:.2f}%"
|
||||
)
|
||||
return module
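End to end, conversion is a single call; the toy model below is illustrative and not part of the patch:

import torch.nn as nn

toy = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
toy = convert_to_lora_module(toy, LoraConfig(r=8, lora_train_bias="none"))
trainable = [name for name, p in toy.named_parameters() if p.requires_grad]
# The original weight matrices are frozen; the low-rank adapters (and biases, depending on
# lora_train_bias) remain trainable, and, when torch.distributed is initialized, rank 0 logs
# the trainable-parameter percentage.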
@ -5,6 +5,7 @@ loss functions
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import masked_mean
|
||||
|
@ -89,11 +90,22 @@ class DpoLoss(nn.Module):
|
|||
"""
|
||||
Dpo loss
|
||||
Details: https://arxiv.org/pdf/2305.18290.pdf
|
||||
|
||||
SimPO loss:
|
||||
Details: https://arxiv.org/pdf/2405.14734.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, beta: float = 0.1):
|
||||
def __init__(self, beta: float = 0.1, gamma: float = 0.0):
|
||||
"""
|
||||
Args:
|
||||
beta: The temperature parameter in the DPO paper.
|
||||
gamma: The margin parameter in the SimPO paper.
|
||||
length_normalization: Whether to normalize the loss by the length of chosen and rejected responses.
|
||||
Refer to the length normalization in the SimPO paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -104,7 +116,7 @@ class DpoLoss(nn.Module):
|
|||
chosen_mask: torch.Tensor,
|
||||
reject_mask: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute the DPO loss for a batch of policy and reference model log probabilities.
|
||||
"""Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities.
|
||||
|
||||
# adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328
|
||||
|
||||
|
@ -113,6 +125,8 @@ class DpoLoss(nn.Module):
|
|||
logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
||||
logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
|
||||
logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
|
||||
chosen_mask: Mask tensor indicating which responses were chosen. Shape: (batch_size,)
|
||||
reject_mask: Mask tensor indicating which responses were rejected. Shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
||||
|
@ -127,13 +141,12 @@ class DpoLoss(nn.Module):
|
|||
if len(logprob_ref_chosen.shape) == 2:
|
||||
ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
|
||||
else:
|
||||
ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
|
||||
ref_logratios = logprob_ref_chosen - logprob_ref_reject
|
||||
else:
|
||||
# If no reference model is provided
|
||||
ref_logratios = 0.0
|
||||
|
||||
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
|
||||
logits = pi_logratios - ref_logratios
|
||||
logits = pi_logratios - ref_logratios - self.gamma / self.beta
|
||||
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
|
||||
|
||||
# Calculate rewards for logging
|
||||
|
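Written out, the objective assembled above is

\mathcal{L}(\theta) = -\log\sigma\Big(\beta\big[(\log\pi_\theta(y_w\mid x)-\log\pi_\theta(y_l\mid x))-(\log\pi_{\mathrm{ref}}(y_w\mid x)-\log\pi_{\mathrm{ref}}(y_l\mid x))\big]-\gamma\Big)

with gamma = 0 recovering plain DPO; for SimPO the reference term is dropped (ref_logratios = 0), gamma > 0 acts as the reward margin, and the log probabilities are length normalized.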
@ -168,3 +181,93 @@ class LogExpLoss(nn.Module):
|
|||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
||||
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
|
||||
return loss
|
||||
|
||||
|
||||
class OddsRatioLoss(nn.Module):
|
||||
"""
|
||||
Odds Ratio Loss in ORPO
|
||||
Details: https://arxiv.org/pdf/2403.07691
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
chosen_logp: torch.Tensor,
|
||||
reject_logp: torch.Tensor,
|
||||
chosen_loss_mask: torch.Tensor,
|
||||
reject_loss_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
chosen_logp = chosen_logp.to(dtype=torch.float32)
|
||||
reject_logp = reject_logp.to(dtype=torch.float32)
|
||||
chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001)
|
||||
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
|
||||
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
|
||||
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
|
||||
log_odds_ratio = chosen_odds_masked - reject_odds_masked
|
||||
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
|
||||
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
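The masked means above implement the ORPO log-odds-ratio term: with odds(p) = p / (1 - p) and per-token log-probabilities averaged over the loss mask, the returned value corresponds to

\log\sigma\Big(\log\tfrac{p_\theta(y_w\mid x)}{1-p_\theta(y_w\mid x)}-\log\tfrac{p_\theta(y_l\mid x)}{1-p_\theta(y_l\mid x)}\Big)

and the 1.0001 offset guards against taking the log of a non-positive number when a token probability approaches 1.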
|
||||
|
||||
|
||||
class KTOLoss(nn.Module):
|
||||
def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
|
||||
"""
|
||||
Args:
|
||||
beta: The temperature parameter in the KTO paper.
|
||||
desirable_weight: The weight for the desirable responses.
|
||||
undesirable_weight: The weight for the undesirable responses.
|
||||
"""
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.desirable_weight = desirable_weight
|
||||
self.undesirable_weight = undesirable_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
chosen_logps: torch.Tensor,
|
||||
rejected_logps: torch.Tensor,
|
||||
kl_logps: torch.Tensor,
|
||||
ref_chosen_logps: torch.Tensor,
|
||||
ref_rejected_logps: torch.Tensor,
|
||||
ref_kl_logps: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Reference:
|
||||
https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585
|
||||
|
||||
Compute the KTO loss for a batch of policy and reference model log probabilities.
|
||||
Args:
|
||||
chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
||||
rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
||||
kl_logps: KL divergence of the policy model. Shape: (batch_size,)
|
||||
ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
|
||||
ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
|
||||
ref_kl_logps: KL divergence of the reference model. Shape: (batch_size,)
|
||||
beta: The temperature parameter in the KTO paper.
|
||||
desirable_weight: The weight for the desirable responses.
|
||||
undesirable_weight: The weight for the undesirable responses.
|
||||
|
||||
Refer to the KTO paper for details about hyperparameters https://arxiv.org/pdf/2402.01306
|
||||
"""
|
||||
kl = (kl_logps - ref_kl_logps).mean().detach()
|
||||
# all-reduce and average the KL estimate across ranks
|
||||
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
|
||||
kl = (kl / dist.get_world_size()).clamp(min=0)
|
||||
|
||||
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
|
||||
chosen_logratios = chosen_logps - ref_chosen_logps
|
||||
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
|
||||
chosen_rewards = self.beta * chosen_logratios.detach()
|
||||
else:
|
||||
chosen_losses = torch.Tensor([]).to(kl_logps.device)
|
||||
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
|
||||
|
||||
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
|
||||
rejected_logratios = rejected_logps - ref_rejected_logps
|
||||
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
|
||||
rejected_rewards = self.beta * rejected_logratios.detach()
|
||||
else:
|
||||
rejected_losses = torch.Tensor([]).to(kl_logps.device)
|
||||
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
|
||||
|
||||
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards, kl
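Schematically, with r_theta(x, y) = \log\pi_\theta(y\mid x) - \log\pi_{\mathrm{ref}}(y\mid x) and z_0 the all-reduced, zero-clamped KL estimate computed above, the loss averaged over the concatenated batch is

\mathcal{L}_{\mathrm{KTO}} = \mathbb{E}\big[\, w_d\,(1-\sigma(\beta(r_\theta(x,y_d)-z_0))) + w_u\,(1-\sigma(\beta(z_0-r_\theta(x,y_u)))) \,\big]

where w_d and w_u are desirable_weight and undesirable_weight.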
|
||||
|
|
|
@ -89,7 +89,9 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
|
|||
return mean
|
||||
|
||||
|
||||
def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
def calc_masked_log_probs(
|
||||
logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor, length_normalization: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the masked log probabilities for a given sequence of logits.
|
||||
|
||||
|
@ -103,7 +105,11 @@ def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mas
|
|||
"""
|
||||
# logits are probabilities of the next token, so we shift them to the left by one
|
||||
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs * mask
|
||||
|
||||
if not length_normalization:
|
||||
return log_probs * mask
|
||||
else:
|
||||
return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)
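With length_normalization=True the masked sum of token log-probabilities becomes (approximately) a per-token mean, which is what SimPO's length-normalized implicit reward expects; a short sketch, reusing the logits/sequences/loss_mask tensors from the trainer code:

per_token = calc_masked_log_probs(logits, sequences, loss_mask[:, 1:], length_normalization=True)
seq_score = per_token.sum(-1)  # ~ average token log-prob per sequence; the +0.01 avoids division by zero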
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
|
|
|
@ -1,7 +1,18 @@
|
|||
from .base import OLTrainer, SLTrainer
|
||||
from .dpo import DPOTrainer
|
||||
from .kto import KTOTrainer
|
||||
from .orpo import ORPOTrainer
|
||||
from .ppo import PPOTrainer
|
||||
from .rm import RewardModelTrainer
|
||||
from .sft import SFTTrainer
|
||||
|
||||
__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"]
|
||||
__all__ = [
|
||||
"SLTrainer",
|
||||
"OLTrainer",
|
||||
"RewardModelTrainer",
|
||||
"SFTTrainer",
|
||||
"PPOTrainer",
|
||||
"DPOTrainer",
|
||||
"ORPOTrainer",
|
||||
"KTOTrainer",
|
||||
]
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Dpo trainer
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
@ -25,7 +26,7 @@ from .utils import is_rank_0, to_device
|
|||
|
||||
class DPOTrainer(SLTrainer):
|
||||
"""
|
||||
Trainer for PPO algorithm.
|
||||
Trainer for DPO algorithm.
|
||||
|
||||
Args:
|
||||
actor (Actor): the actor model in ppo algorithm
|
||||
|
@ -53,6 +54,8 @@ class DPOTrainer(SLTrainer):
|
|||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
beta: float = 0.1,
|
||||
gamma: float = 0.0,
|
||||
length_normalization: bool = False,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
|
@ -63,7 +66,7 @@ class DPOTrainer(SLTrainer):
|
|||
self.ref_model = ref_model
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.actor_loss_fn = DpoLoss(beta)
|
||||
self.actor_loss_fn = DpoLoss(beta, gamma)
|
||||
self.save_interval = save_interval
|
||||
self.coordinator = coordinator
|
||||
self.save_dir = save_dir
|
||||
|
@ -71,6 +74,7 @@ class DPOTrainer(SLTrainer):
|
|||
self.accumulation_steps = accumulation_steps
|
||||
self.device = get_current_device()
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
self.length_normalization = length_normalization
|
||||
|
||||
def _before_fit(
|
||||
self,
|
||||
|
@ -131,18 +135,21 @@ class DPOTrainer(SLTrainer):
|
|||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
reject_loss_mask[:, -1] = False
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
|
||||
actor_all_logits = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
|
@ -150,14 +157,14 @@ class DPOTrainer(SLTrainer):
|
|||
ref_all_logits = self.ref_model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
else:
|
||||
logprob_ref_chosen = None
|
||||
|
@ -219,7 +226,7 @@ class DPOTrainer(SLTrainer):
|
|||
)
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if (self.num_train_step + 1) % self.save_interval == 0:
|
||||
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
|
@ -283,16 +290,16 @@ class DPOTrainer(SLTrainer):
|
|||
actor_all_logits = self.model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
self.ref_model.eval()
|
||||
|
@ -300,11 +307,15 @@ class DPOTrainer(SLTrainer):
|
|||
ref_all_logits = self.ref_model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
|
||||
logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
logprob_actor_chosen,
|
||||
|
@ -314,7 +325,7 @@ class DPOTrainer(SLTrainer):
|
|||
chosen_loss_mask[:, 1:],
|
||||
reject_loss_mask[:, 1:],
|
||||
)
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
||||
loss = losses.mean()
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
|
@ -333,4 +344,7 @@ class DPOTrainer(SLTrainer):
|
|||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
@ -0,0 +1,318 @@
|
|||
"""
|
||||
KTO trainer
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from coati.models.loss import KTOLoss
|
||||
from coati.models.utils import calc_masked_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import trange
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import SLTrainer
|
||||
from .utils import is_rank_0, to_device
|
||||
|
||||
|
||||
class KTOTrainer(SLTrainer):
|
||||
"""
|
||||
Trainer for KTO algorithm.
|
||||
|
||||
Args:
|
||||
actor (Actor): the actor model in kto algorithm
|
||||
ref_model (Critic): the reference model in kto algorithm
|
||||
booster (Strategy): the strategy to use for training
|
||||
actor_optim (Optimizer): the optimizer to use for actor model
|
||||
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
|
||||
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
|
||||
max_epochs (int, defaults to 1): the max number of epochs to train
|
||||
accumulation_steps (int): the number of steps to accumulate gradients
|
||||
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
|
||||
save_interval (int): the interval to save model checkpoints; defaults to 0, which means no checkpoint will be saved during training
|
||||
save_dir (str): the directory to save checkpoints
|
||||
coordinator (DistCoordinator): the coordinator to use for distributed logging
|
||||
beta (float, defaults to 0.1): the beta parameter in kto loss
|
||||
desirable_weight (float, defaults to 1.0): the weight for desirable reward
|
||||
undesirable_weight (float, defaults to 1.0): the weight for undesirable reward
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: Any,
|
||||
ref_model: Any,
|
||||
booster: Booster,
|
||||
actor_optim: Optimizer,
|
||||
actor_lr_scheduler: _LRScheduler,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
beta: float = 0.1,
|
||||
desirable_weight: float = 1.0,
|
||||
undesirable_weight: float = 1.0,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
save_dir: str = None,
|
||||
coordinator: DistCoordinator = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
||||
self.ref_model = ref_model
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
|
||||
self.save_interval = save_interval
|
||||
self.coordinator = coordinator
|
||||
self.save_dir = save_dir
|
||||
self.num_train_step = 0
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.device = get_current_device()
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
self.desirable_weight = desirable_weight
|
||||
self.undesirable_weight = undesirable_weight
|
||||
self.beta = beta
|
||||
|
||||
def _before_fit(
|
||||
self,
|
||||
train_preference_dataloader: DataLoader = None,
|
||||
eval_preference_dataloader: DataLoader = None,
|
||||
log_dir: Optional[str] = None,
|
||||
use_wandb: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
train_preference_dataloader (DataLoader): the dataloader to use for KTO training data
eval_preference_dataloader (DataLoader): the dataloader to use for KTO evaluation data
|
||||
"""
|
||||
self.train_dataloader = train_preference_dataloader
|
||||
self.eval_dataloader = eval_preference_dataloader
|
||||
self.writer = None
|
||||
if use_wandb and is_rank_0():
|
||||
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||
import wandb
|
||||
|
||||
self.wandb_run = wandb.init(project="Coati-kto", sync_tensorboard=True)
|
||||
if log_dir is not None and is_rank_0():
|
||||
import os
|
||||
import time
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
log_dir = os.path.join(log_dir, "kto")
|
||||
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||
self.writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def _train(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch (int): the current epoch number
|
||||
"""
|
||||
self.model.train()
|
||||
self.accumulative_meter.reset()
|
||||
step_bar = trange(
|
||||
len(self.train_dataloader) // self.accumulation_steps,
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
|
||||
batch["input_ids"],
|
||||
batch["attention_mask"],
|
||||
batch["loss_mask"],
|
||||
batch["label"],
|
||||
batch["kl_input_ids"],
|
||||
batch["kl_attention_mask"],
|
||||
batch["kl_loss_mask"],
|
||||
)
|
||||
batch_size = input_ids.size()[0]
|
||||
|
||||
# actor logits
|
||||
with torch.no_grad():
|
||||
# estimate the KL term from the mismatched (kl_*) prompt/completion pairs
|
||||
kl_logits = self.model(
|
||||
input_ids=kl_input_ids,
|
||||
attention_mask=kl_attention_mask,
|
||||
)["logits"]
|
||||
|
||||
logits = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)["logits"]
|
||||
|
||||
logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
|
||||
kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
|
||||
chosen_index = [i for i in range(batch_size) if label[i] == 1]
|
||||
rejected_index = [i for i in range(batch_size) if label[i] == 0]
|
||||
chosen_logprob = logprob[chosen_index]
|
||||
rejected_logprob = logprob[rejected_index]
|
||||
with torch.no_grad():
|
||||
ref_kl_logits = self.ref_model(
|
||||
input_ids=kl_input_ids,
|
||||
attention_mask=kl_attention_mask,
|
||||
)["logits"]
|
||||
ref_logits = self.ref_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)["logits"]
|
||||
|
||||
ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
|
||||
ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
|
||||
ref_chosen_logprob = ref_logprob[chosen_index]
|
||||
ref_rejected_logprob = ref_logprob[rejected_index]
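# Hedged sketch of what self.kto_loss computes below (see KTOLoss for the exact
# implementation): per-sample implicit rewards are roughly beta * (policy logprob -
# reference logprob); the kl_* tensors hold mismatched prompt/completion pairs used to
# estimate a batch-level KL reference point, and desirable / undesirable samples are
# weighted by desirable_weight / undesirable_weight respectively.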
|
||||
|
||||
loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
||||
chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
|
||||
)
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.actor_scheduler.step()
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
|
||||
|
||||
if i % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.num_train_step += 1
|
||||
step_bar.update()
|
||||
# logging
|
||||
if self.writer and is_rank_0():
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/margin",
|
||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if self.save_dir is not None and self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
save_dir=self.save_dir,
|
||||
booster=self.booster,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.actor_scheduler,
|
||||
epoch=epoch,
|
||||
step=i + 1,
|
||||
batch_size=batch_size,
|
||||
coordinator=self.coordinator,
|
||||
)
|
||||
self.coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
||||
)
|
||||
|
||||
step_bar.close()
|
||||
|
||||
def _eval(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch (int): the current epoch number
|
||||
"""
|
||||
if self.eval_dataloader is None:
|
||||
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
||||
return
|
||||
self.model.eval()
|
||||
self.accumulative_meter.reset()
|
||||
step_bar = trange(
|
||||
len(self.eval_dataloader),
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.eval_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
|
||||
batch["input_ids"],
|
||||
batch["attention_mask"],
|
||||
batch["loss_mask"],
|
||||
batch["label"],
|
||||
batch["kl_input_ids"],
|
||||
batch["kl_attention_mask"],
|
||||
batch["kl_loss_mask"],
|
||||
)
|
||||
batch_size = input_ids.size()[0]
|
||||
|
||||
# actor logits
|
||||
with torch.no_grad():
|
||||
# estimate the KL term from the mismatched (kl_*) prompt/completion pairs
|
||||
kl_logits = self.model(
|
||||
input_ids=kl_input_ids,
|
||||
attention_mask=kl_attention_mask,
|
||||
)["logits"]
|
||||
|
||||
logits = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)["logits"]
|
||||
|
||||
logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
|
||||
kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
|
||||
chosen_index = [i for i in range(batch_size) if label[i] == 1]
|
||||
rejected_index = [i for i in range(batch_size) if label[i] == 0]
|
||||
chosen_logprob = logprob[chosen_index]
|
||||
rejected_logprob = logprob[rejected_index]
|
||||
with torch.no_grad():
|
||||
ref_kl_logits = self.ref_model(
|
||||
input_ids=kl_input_ids,
|
||||
attention_mask=kl_attention_mask,
|
||||
)["logits"]
|
||||
|
||||
ref_logits = self.ref_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)["logits"]
|
||||
|
||||
ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
|
||||
ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
|
||||
ref_chosen_logprob = ref_logprob[chosen_index]
|
||||
ref_rejected_logprob = ref_logprob[rejected_index]
|
||||
|
||||
loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
||||
chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
|
||||
)
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
|
||||
self.accumulative_meter.add(
|
||||
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
|
||||
)
|
||||
step_bar.update()
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "margin"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
|
@ -0,0 +1,314 @@
|
|||
"""
|
||||
Orpo trainer
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from coati.models.loss import OddsRatioLoss
|
||||
from coati.models.utils import calc_masked_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import trange
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import SLTrainer
|
||||
from .utils import is_rank_0, to_device
|
||||
|
||||
|
||||
class ORPOTrainer(SLTrainer):
|
||||
"""
|
||||
Trainer for ORPO algorithm.
|
||||
|
||||
Args:
|
||||
actor (Actor): the policy model trained with the ORPO algorithm
|
||||
booster (Booster): the booster to use for training
|
||||
actor_optim (Optimizer): the optimizer to use for actor model
|
||||
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
|
||||
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
|
||||
max_epochs (int, defaults to 1): the max number of epochs to train
|
||||
lam (float, defaults to 0.1): the lambda parameter in ORPO loss
|
||||
accumulation_steps (int): the number of steps to accumulate gradients
|
||||
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
|
||||
save_interval (int, defaults to 0): the number of steps between checkpoint saves; 0 means no checkpoint is saved during training
|
||||
save_dir (str): the directory to save checkpoints
|
||||
coordinator (DistCoordinator): the coordinator to use for distributed logging
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: Any,
|
||||
booster: Booster,
|
||||
actor_optim: Optimizer,
|
||||
actor_lr_scheduler: _LRScheduler,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
lam: float = 0.1,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
save_dir: str = None,
|
||||
coordinator: DistCoordinator = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.odds_ratio_loss_fn = OddsRatioLoss()
|
||||
self.save_interval = save_interval
|
||||
self.coordinator = coordinator
|
||||
self.save_dir = save_dir
|
||||
self.num_train_step = 0
|
||||
self.lam = lam
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.device = get_current_device()
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
|
||||
def _before_fit(
|
||||
self,
|
||||
train_preference_dataloader: DataLoader = None,
|
||||
eval_preference_dataloader: DataLoader = None,
|
||||
log_dir: Optional[str] = None,
|
||||
use_wandb: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
train_preference_dataloader (DataLoader): the dataloader for the ORPO training (preference) data
|
||||
eval_preference_dataloader (DataLoader, optional): the dataloader for the ORPO evaluation data
|
||||
log_dir (str, optional): the directory to save tensorboard logs
|
||||
use_wandb (bool, defaults to False): whether to log the run to wandb
|
||||
"""
|
||||
self.train_dataloader = train_preference_dataloader
|
||||
self.eval_dataloader = eval_preference_dataloader
|
||||
self.writer = None
|
||||
if use_wandb and is_rank_0():
|
||||
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||
import wandb
|
||||
|
||||
self.wandb_run = wandb.init(project="Coati-orpo", sync_tensorboard=True)
|
||||
if log_dir is not None and is_rank_0():
|
||||
import os
|
||||
import time
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
log_dir = os.path.join(log_dir, "orpo")
|
||||
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||
self.writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def _train(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch (int): the current epoch number
|
||||
"""
|
||||
self.model.train()
|
||||
self.accumulative_meter.reset()
|
||||
step_bar = trange(
|
||||
len(self.train_dataloader) // self.accumulation_steps,
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
batch_size = chosen_input_ids.size()[0]
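# The chosen and rejected sequences are concatenated into a single forward pass; the
# labels for the rejected half are set to -100 so that actor_out["loss"] below is the
# NLL over the chosen responses only.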
|
||||
actor_out = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
labels=torch.cat(
|
||||
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
|
||||
),
|
||||
)
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
actor_all_logits = actor_out["logits"].to(torch.float32)
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
|
||||
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
|
||||
chosen_nll = actor_out["loss"]
|
||||
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
|
||||
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
|
||||
)
|
||||
loss = chosen_nll - odds_ratio_loss * self.lam
|
||||
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.actor_scheduler.step()
|
||||
|
||||
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
|
||||
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
|
||||
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
|
||||
if i % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.num_train_step += 1
|
||||
step_bar.update()
|
||||
# logging
|
||||
if self.writer and is_rank_0():
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/margin",
|
||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/accuracy",
|
||||
self.accumulative_meter.get("accuracy"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/log_odds_ratio",
|
||||
self.accumulative_meter.get("log_odds_ratio"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if self.save_dir is not None and self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
save_dir=self.save_dir,
|
||||
booster=self.booster,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.actor_scheduler,
|
||||
epoch=epoch,
|
||||
step=i + 1,
|
||||
batch_size=batch_size,
|
||||
coordinator=self.coordinator,
|
||||
)
|
||||
self.coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
||||
)
|
||||
|
||||
step_bar.close()
|
||||
|
||||
def _eval(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch (int): the current epoch number
|
||||
"""
|
||||
if self.eval_dataloader is None:
|
||||
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
||||
return
|
||||
self.model.eval()
|
||||
self.coordinator.print_on_master("\nStart evaluation...")
|
||||
|
||||
step_bar = trange(
|
||||
len(self.eval_dataloader),
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
with torch.no_grad():
|
||||
for i, batch in enumerate(self.eval_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
actor_out = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
labels=torch.cat(
|
||||
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
|
||||
),
|
||||
)
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
actor_all_logits = actor_out["logits"].to(torch.float32)
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
|
||||
)
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
|
||||
)
|
||||
chosen_nll = actor_out["loss"]
|
||||
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
|
||||
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
|
||||
)
|
||||
loss = chosen_nll - odds_ratio_loss * self.lam
|
||||
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
|
||||
|
||||
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
|
||||
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
|
||||
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
|
@ -237,6 +237,7 @@ class RewardModelTrainer(SLTrainer):
|
|||
+ f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n"
|
||||
)
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
|
|
|
@ -102,6 +102,7 @@ class SFTTrainer(SLTrainer):
|
|||
batch_size = batch["input_ids"].size(0)
|
||||
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||
loss = outputs.loss
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
|
@ -113,6 +114,7 @@ class SFTTrainer(SLTrainer):
|
|||
self.optimizer.zero_grad()
|
||||
self.scheduler.step()
|
||||
|
||||
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
|
||||
if self.writer:
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
|
||||
|
@ -165,6 +167,7 @@ class SFTTrainer(SLTrainer):
|
|||
for tag in ["loss"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
151645,
|
||||
151643
|
||||
],
|
||||
"end_of_assistant": "<|im_end|>"
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
2
|
||||
],
|
||||
"end_of_assistant": "</s>"
|
||||
}
|
|
@ -9,6 +9,7 @@
|
|||
- [Install Requirements](#install-requirements)
|
||||
- [Get Start with ColossalRun](#get-start-with-colossalrun)
|
||||
- [Training Configuration](#training-configuration)
|
||||
- [Parameter Efficient Finetuning (PEFT)](#parameter-efficient-finetuning-peft)
|
||||
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
|
||||
- [Step 1: Data Collection](#step-1-data-collection)
|
||||
- [Step 2: Preprocessing](#step-2-preprocessing)
|
||||
|
@ -29,6 +30,9 @@
|
|||
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
|
||||
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
|
||||
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
|
||||
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
|
||||
- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
|
||||
- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
|
||||
- [List of Supported Models](#list-of-supported-models)
|
||||
- [Hardware Requirements](#hardware-requirements)
|
||||
- [Inference example](#inference-example)
|
||||
|
@ -45,9 +49,6 @@
|
|||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
## Get Start with ColossalRun
|
||||
|
||||
|
||||
|
@ -81,8 +82,6 @@ Make sure the master node can access all nodes (including itself) by ssh without
|
|||
This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
|
||||
|
||||
|
||||
|
||||
|
||||
<details><summary><b>Gemini (Zero3)</b></summary>
|
||||
|
||||
|
||||
|
@ -374,35 +373,6 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Low Rank Adaption</b></summary>
|
||||
|
||||
|
||||
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources.
|
||||
|
||||
|
||||
To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
|
||||
```
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_interval 5000 \
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--plugin zero2_cpu \
|
||||
--batch_size 4 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--lr 2e-5 \
|
||||
--max_len 2048 \
|
||||
--lora_rank 32 \ # This enables LoRA
|
||||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary><b>Other Training Arguments</b></summary>
|
||||
|
||||
|
||||
|
@ -427,6 +397,60 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
|||
- use_wandb: if this flag is up, you can view logs on wandb.
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
### Parameter Efficient Finetuning (PEFT)
|
||||
|
||||
Currently, we support LoRA (low-rank adaptation) and PiSSA (principal singular values and singular vectors adaptation). Both reduce runtime VRAM consumption and training time at some cost to overall model performance.
|
||||
|
||||
|
||||
<details><summary><b>Low-Rank Adaptation and PiSSA</b></summary>
|
||||
|
||||
|
||||
Details about Low-Rank Adaptation (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). Details about Principal Singular Values and Singular Vectors Adaptation (PiSSA) can be found in the paper: [PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models](https://arxiv.org/abs/2404.02948). Both reduce runtime VRAM consumption and training time at some cost to overall model performance, which makes them suitable for training LLMs with constrained resources.
|
||||
|
||||
To use LoRA/PiSSA in training, please create a config file as in the following example and set the `--lora_config` to that configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"r": 128,
|
||||
"embedding_lora_dropout": 0.0,
|
||||
"linear_lora_dropout": 0.1,
|
||||
"lora_alpha": 32,
|
||||
"lora_train_bias": "all",
|
||||
"lora_initialization_method": "PiSSA",
|
||||
"target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
|
||||
}
|
||||
```
|
||||
#### LoRA Parameters
|
||||
- r: lora rank
|
||||
- embedding_lora_dropout: dropout probability for embedding layer
|
||||
- linear_lora_dropout: dropout probability for linear layer
|
||||
- lora_alpha: LoRA alpha; controls how much the adapter can deviate from the pretrained model.
|
||||
- lora_train_bias: which biases to train. Choose from "all" (all layers, not only LoRA layers, get trainable biases), "none" (no trainable biases), or "lora" (only LoRA layers get trainable biases).
|
||||
- lora_initialization_method: how to initialize lora weights, choose one from ["kaiming_uniform", "PiSSA"], default to "kaiming_uniform". Use "kaiming_uniform" for standard LoRA and "PiSSA" for PiSSA.
|
||||
- target_modules: which modules should be converted to LoRA layers. A module is converted if its name contains one of the keywords in target_modules and it is a linear or embedding layer; otherwise it is frozen. Setting this field to None automatically converts all linear and embedding layers to their LoRA counterparts (see the sketch below). Note that this example only works for LLaMA; for other models, you need to adjust the module names.
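The matching rule above can be pictured with a tiny illustrative snippet (the helper below is hypothetical, not part of the codebase):

```python
# illustrative only: substring matching of module names against target_modules
target_modules = ["q_proj", "o_proj", "k_proj", "v_proj"]

def should_convert(module_name: str, is_linear_or_embedding: bool) -> bool:
    # a module is converted to its LoRA counterpart when its name contains one of the
    # keywords and it is a linear or embedding layer; otherwise it stays frozen
    return is_linear_or_embedding and any(key in module_name for key in target_modules)

print(should_convert("model.layers.0.self_attn.q_proj", True))  # True  -> LoRA layer
print(should_convert("model.layers.0.mlp.gate_proj", True))     # False -> stays frozen
```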
|
||||
|
||||
|
||||
```
|
||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_interval 5000 \
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--plugin zero2_cpu \
|
||||
--batch_size 4 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--lr 2e-5 \
|
||||
--max_len 2048 \
|
||||
--lora_config /PATH/TO/THE/LORA/CONFIG/FILE.json \ # Setting this enables LoRA
|
||||
--use_wandb
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
@ -445,7 +469,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
|
|||
{"messages":
|
||||
[
|
||||
{
|
||||
"from": "human",
|
||||
"from": "user",
|
||||
"content": "what are some pranks with a pen i can do?"
|
||||
},
|
||||
{
|
||||
|
@ -470,9 +494,15 @@ In this code we provide a flexible way for users to set the conversation templat
|
|||
- Step 1: (Optional). Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields.
|
||||
```json
|
||||
{
|
||||
"chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating,
|
||||
"system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added,
|
||||
"end_of_assistant": The token(s) in string that denotes the end of assistance's response. For example, in the ChatGLM2 prompt format,
|
||||
"chat_template": "A string of chat_template used for formatting chat data",
|
||||
"system_message": "A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added",
|
||||
"end_of_assistant": "The token(s) in string that denotes the end of assistance's response",
|
||||
"stop_ids": "A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training"
|
||||
}
|
||||
```
|
||||
* `chat_template`: (Optional) a Jinja chat template string used for formatting chat data. If not set (None), the default chat template of the provided tokenizer is used; if a path to a Hugging Face or local model is provided, that model's chat template is used. To use a custom chat template, you need to set this field manually. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating.
|
||||
* `system_message`: a system message string added at the beginning of the prompt. If none is provided (None), no system message is added.
|
||||
* `end_of_assistant`: the token(s), as a string, that denote the end of the assistant's response. For example, in the ChatGLM2 prompt format,
|
||||
```
|
||||
<|im_start|>system
|
||||
system messages
|
||||
|
@ -481,15 +511,13 @@ In this code we provide a flexible way for users to set the conversation templat
|
|||
<|im_start|>user
|
||||
How far is the moon? <|im_end|>
|
||||
<|im_start|>assistant\n The moon is about 384,400 kilometers away from Earth.<|im_end|>...
|
||||
```
|
||||
the end_of_assistant tokens are "<|im_end|>"
|
||||
"stop_ids": (Optional), A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically
|
||||
}
|
||||
```
|
||||
On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message),
|
||||
```
|
||||
the `end_of_assistant` tokens are "<|im_end|>"
|
||||
* `stop_ids`: (Optional) a list of integers corresponding to the `end_of_assistant` tokens that indicate the end of the assistant's response during the rollout stage of PPO training. It is recommended to set this manually for PPO training; if not set, it will be set to tokenizer.eos_token_ids automatically.
|
||||
|
||||
On your first run of the data preparation script, you only need to define the `chat_template` (if you want to use a custom chat template) and the `system_message` (if you want to use a custom system message).
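If you want to sanity-check a custom `chat_template` before running the data pipeline, you can render it directly with the Hugging Face tokenizer API linked above; the model path below is only an example and any tokenizer works:

```python
from transformers import AutoTokenizer

# example tokenizer; substitute your own model path
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat")
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)
messages = [
    {"role": "system", "content": "A chat between a curious human and an artificial intelligence assistant."},
    {"role": "user", "content": "How far is the moon?"},
]
# prints the fully formatted prompt, ending with the assistant generation header
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```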
|
||||
|
||||
- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
|
||||
- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
|
||||
|
||||
|
||||
- Step 3: (Optional) Check the correctness of the processed data. We provided an easy way for you to do a manual checking on the processed data by checking the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files.
|
||||
|
@ -509,7 +537,7 @@ Human: <s> what are some pranks with a pen i can do?</s> Assistant: <s> Are you
|
|||
|
||||
|
||||
#### Step 3: Training
|
||||
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./training_scripts/train_sft.sh) to start supervised instruction fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
|
||||
|
||||
### RLHF Training Stage2 - Training Reward Model
|
||||
|
@ -526,7 +554,7 @@ Below shows the preference dataset format used in training the reward model.
|
|||
[
|
||||
{"context": [
|
||||
{
|
||||
"from": "human",
|
||||
"from": "user",
|
||||
"content": "Introduce butterflies species in Oregon."
|
||||
}
|
||||
]
|
||||
|
@ -551,11 +579,11 @@ Below shows the preference dataset format used in training the reward model.
|
|||
|
||||
|
||||
#### Step 2: Preprocessing
|
||||
Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
|
||||
Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
|
||||
|
||||
|
||||
#### Step 3: Training
|
||||
You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
You can run [train_rm.sh](./training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
|
||||
|
||||
#### Features and Tricks in RM Training
|
||||
|
@ -595,7 +623,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
|
|||
|
||||
|
||||
#### Step 1: Data Collection
|
||||
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
|
||||
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory; data samples within the prompt dataset end with a line from the "user", and the "assistant" needs to generate a response to the "user". Note that you can still use conversations that end with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
|
||||
|
||||
|
||||
```json
|
||||
|
@ -603,7 +631,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
|
|||
{"messages":
|
||||
[
|
||||
{
|
||||
"from": "human",
|
||||
"from": "user",
|
||||
"content": "what are some pranks with a pen i can do?"
|
||||
}
|
||||
...
|
||||
|
@ -626,14 +654,14 @@ The second dataset--- pretrained dataset is optional, provide it if you want to
|
|||
]
|
||||
```
|
||||
#### Step 2: Preprocessing
|
||||
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
|
||||
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./data_preparation_scripts/prepare_prompt_dataset.sh)
|
||||
|
||||
|
||||
You can use the SFT dataset you prepared in the SFT stage or prepare a new one from different source for the ptx dataset. The ptx data is used to calculate ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
|
||||
|
||||
|
||||
#### Step 3: Training
|
||||
You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
You can run the [train_ppo.sh](./training_scripts/train_ppo.sh) script to start PPO training. Here are some arguments unique to PPO; please refer to the [training configuration](#training-configuration) section for details regarding other supported training options.
|
||||
|
||||
|
||||
```bash
|
||||
|
@ -717,17 +745,90 @@ For DPO training, you only need the preference dataset. Please follow the instru
|
|||
|
||||
|
||||
#### Step 2: Training
|
||||
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
|
||||
You can run the [train_dpo.sh](./training_scripts/train_dpo.sh) script to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following recent research on DPO-like alignment methods, we added several options for the user to choose from, including length normalization, reward shaping, and whether to use a reference model when calculating the implicit reward. Here are those options:
|
||||
|
||||
```
|
||||
--beta 0.1 \ # the temperature in the DPO loss, defaults to 0.1
|
||||
--gamma 0.0 \ # the reward target margin in the SimPO paper, defaults to 0
|
||||
--disable_reference_model \ # if set, the reference model is disabled and the implicit reward is calculated solely from the actor; by default DPO uses a reference model
|
||||
--length_normalization \ # whether to apply length normalization; off by default
|
||||
```
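To make these options concrete, here is a rough, self-contained sketch of the loss they control (illustrative only; the variable names are hypothetical and this is not the exact Coati implementation):

```python
import torch
import torch.nn.functional as F

def dpo_family_loss(policy_chosen_logp, policy_rejected_logp,
                    ref_chosen_logp, ref_rejected_logp,
                    chosen_len, rejected_len,
                    beta=0.1, gamma=0.0,
                    length_normalization=False, use_reference_model=True):
    """Rough sketch of the loss the flags above control."""
    if length_normalization:
        # use the average per-token log probability instead of the sequence sum
        policy_chosen_logp = policy_chosen_logp / chosen_len
        policy_rejected_logp = policy_rejected_logp / rejected_len
        ref_chosen_logp = ref_chosen_logp / chosen_len
        ref_rejected_logp = ref_rejected_logp / rejected_len
    if not use_reference_model:  # corresponds to --disable_reference_model
        ref_chosen_logp = torch.zeros_like(ref_chosen_logp)
        ref_rejected_logp = torch.zeros_like(ref_rejected_logp)
    # implicit rewards relative to the reference, with the reward target margin gamma
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    return -F.logsigmoid(chosen_reward - rejected_reward - gamma).mean()
```

Under this sketch, plain DPO corresponds to `gamma=0.0` with the reference model enabled, while the SimPO-style variant described below roughly corresponds to enabling length normalization, using a positive `gamma`, and dropping the reference model.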
|
||||
|
||||
#### DPO Result
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/DPO.png">
|
||||
</p>
|
||||
|
||||
### Alternative Option For RLHF: Simple Preference Optimization
|
||||
|
||||
We support the method introduced in the paper [SimPO: Simple Preference Optimization
|
||||
with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate much from DPO, we added support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO for alignment, use the [train_dpo.sh](./training_scripts/train_dpo.sh) script and set `loss_type` to `simpo_loss`; you can optionally also set the temperature (`beta`) and the reward target margin (`gamma`).
|
||||
|
||||
#### SimPO Result
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/SimPO_margin.png">
|
||||
</p>
|
||||
|
||||
|
||||
### Alternative Option For RLHF: Odds Ratio Preference Optimization
|
||||
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), a reference-model-free alignment method that mixes the SFT loss with a reinforcement-learning-style loss that uses the odds ratio as the implicit reward to enhance training stability and efficiency. To use ORPO for alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script; you can optionally set `lambda` (which determines how strongly the odds-ratio loss affects training). A rough sketch of the objective is shown below.
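Illustrative only (the names are hypothetical; the exact formulation lives in `OddsRatioLoss`):

```python
import torch
import torch.nn.functional as F

def orpo_loss_sketch(chosen_nll, chosen_logp, rejected_logp, lam=0.1):
    # chosen_logp / rejected_logp: per-token-averaged log probabilities (values <= 0)
    # odds(y) = p(y) / (1 - p(y)); work in log space for numerical stability
    eps = 1e-6
    log_odds_chosen = chosen_logp - torch.log1p(-torch.exp(chosen_logp).clamp(max=1 - eps))
    log_odds_rejected = rejected_logp - torch.log1p(-torch.exp(rejected_logp).clamp(max=1 - eps))
    log_odds_ratio = F.logsigmoid(log_odds_chosen - log_odds_rejected)
    # SFT loss on the chosen response minus the lambda-weighted odds-ratio term
    return chosen_nll - lam * log_odds_ratio.mean()
```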
|
||||
|
||||
#### ORPO Result
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ORPO_margin.png">
|
||||
</p>
|
||||
|
||||
### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
|
||||
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results.
|
||||
|
||||
For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examples/data_preparation_scripts/prepare_kto_dataset.sh). You will need preference data, but unlike DPO and its derivatives, you no longer need a pair of chosen/rejected responses for the same input. You only need data whose response is associated with a preference label--- whether the response is desirable or not; read the paper for more details. You also need to convert your data to the following intermediate format before you run the data preparation script.
|
||||
|
||||
```jsonl
|
||||
{
|
||||
"prompt": [
|
||||
{
|
||||
"from": "user",
|
||||
"content": "What are some praise words in english?"
|
||||
},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."
|
||||
},
|
||||
{
|
||||
"from": "user",
|
||||
"content": "What's your favorite one?"
|
||||
}
|
||||
],
|
||||
"completion": {
|
||||
"from": "assistant",
|
||||
"content": "impressive."
|
||||
},
|
||||
"label": true
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
For training, use the [train_kto.sh](./examples/training_scripts/train_kto.sh) script. You may need to set the value for `beta` (the beta parameter in the KTO loss), as well as `desirable_weight` and `undesirable_weight` if your data is imbalanced (has an unequal number of desirable and undesirable samples); see the sketch below for one way to pick the weights.
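If your dataset is imbalanced, one heuristic in the spirit of the KTO paper (treat it as a starting point, not an official recommendation of this project) is to scale the two weights so that both classes contribute comparably to the loss:

```python
# hypothetical counts taken from your own dataset
num_desirable, num_undesirable = 2000, 6000

desirable_weight = 1.0
# keep desirable_weight * num_desirable roughly comparable to
# undesirable_weight * num_undesirable
undesirable_weight = desirable_weight * num_desirable / num_undesirable  # ~0.33
print(desirable_weight, round(undesirable_weight, 2))
```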
|
||||
|
||||
#### KTO Result
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
|
||||
</p>
|
||||
|
||||
## Hardware Requirements
|
||||
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use an H800 GPU with 80GB VRAM.
|
||||
|
||||
For SFT, we recommend using zero2 or zero2-cpu for a 7B model, and tensor parallelism (tp) if your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
|
||||
- 2 H800 GPU
|
||||
- zero2-cpu, micro batch size=4, VRAM Usage=22457.98 MB
|
||||
- zero2, micro batch size=4, VRAM Usage=72390.95 MB
|
||||
- 4 H800 GPUs
|
||||
- zero2_cpu, micro batch size=8, VRAM Usage=19412.77 MB
|
||||
- zero2, micro batch size=8, VRAM Usage=43446.31 MB
|
||||
- zero2, micro batch size=16, VRAM Usage=58082.30 MB
|
||||
- zero2, micro batch size=8, lora_rank=8, VRAM Usage=21167.73 MB
|
||||
- zero2, micro batch size=8, lora_rank=32, VRAM Usage=21344.17 MB
|
||||
|
||||
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model (llama2-7B-hf) on a dummy dataset with a sequence length of 2048 and a layout length of 512 with different tp_size (equal to the number of GPUs).
|
||||
| PPO | tp=8 | tp=4 |
|
||||
|-------|---------------|---------------|
|
||||
| bs=1 | 18485.19 MB | 42934.45 MB |
|
||||
|
@ -738,12 +839,39 @@ For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM
|
|||
|
||||
For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
|
||||
|
||||
|
||||
- 1 H800 GPU
|
||||
- zero2-cpu, batch size=2, VRAM Usage=49873.90 MB
|
||||
- zero2-cpu, batch size=4, VRAM Usage=60998.22 MB
|
||||
- 2 H800 GPU
|
||||
- zero2-cpu, micro batch size=2, VRAM Usage=36989.37 MB
|
||||
- zero2-cpu, micro batch size=4, VRAM Usage=48081.67 MB
|
||||
- 4 H800 GPUs
|
||||
- zero2, batch size=4, VRAM Usage=67544.47 MB
|
||||
- zero2, micro batch size=4, VRAM Usage=67483.44 MB
|
||||
|
||||
For SimPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
|
||||
|
||||
- 2 H800 GPU
|
||||
- zero2-cpu, micro batch size=4, VRAM 25705.26 MB
|
||||
- zero2, micro batch size=4, VRAM Usage=73375.04 MB
|
||||
- 4 H800 GPUs
|
||||
- zero2_cpu, micro batch size=8, VRAM Usage=36709.36 MB
|
||||
- zero2, micro batch size=4, VRAM Usage=44330.90 MB
|
||||
- zero2, micro batch size=8, VRAM Usage=56086.12 MB
|
||||
|
||||
For ORPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
|
||||
|
||||
- 2 H800 GPU
|
||||
- zero2-cpu, micro batch size=4, VRAM 26693.38 MB
|
||||
- zero2, micro batch size=4, VRAM Usage=74332.65 MB
|
||||
- 4 H800 GPUs
|
||||
- zero2_cpu, micro batch size=8, VRAM Usage=38709.73 MB
|
||||
- zero2, micro batch size=4, VRAM Usage=45309.52 MB
|
||||
- zero2, micro batch size=8, VRAM Usage=58086.37 MB
|
||||
|
||||
For KTO, we recommend using the zero2-cpu or zero2 plugin. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
|
||||
- 2 H800 GPU
|
||||
- zero2-cpu, micro batch size=2, VRAM Usage=35241.98 MB
|
||||
- zero2-cpu, micro batch size=4, VRAM Usage=38989.37 MB
|
||||
- 4 H800 GPUs
|
||||
- zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB
|
||||
- zero2, micro batch size=4, VRAM_USAGE=59307.97 MB
|
||||
|
||||
## List of Supported Models
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ import random
|
|||
import time
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
|
||||
from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
|
||||
from datasets import dataset_dict, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
@ -56,8 +56,8 @@ def main():
|
|||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
choices=["sft", "prompt", "preference"],
|
||||
help="Type of dataset, chose from 'sft', 'prompt', 'preference'.",
|
||||
choices=["sft", "prompt", "preference", "kto"],
|
||||
help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_input_dirs",
|
||||
|
@ -199,11 +199,13 @@ def main():
|
|||
)
|
||||
|
||||
if args.type == "sft":
|
||||
preparation_function = supervised_tokenize_sft
|
||||
preparation_function = tokenize_sft
|
||||
elif args.type == "prompt":
|
||||
preparation_function = tokenize_prompt_dataset
|
||||
preparation_function = tokenize_prompt
|
||||
elif args.type == "preference":
|
||||
preparation_function = tokenize_rlhf
|
||||
elif args.type == "kto":
|
||||
preparation_function = tokenize_kto
|
||||
else:
|
||||
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
|
||||
|
||||
|
@ -228,10 +230,13 @@ def main():
|
|||
keep_in_memory=False,
|
||||
num_proc=min(len(dataset), cpu_count()),
|
||||
)
|
||||
|
||||
dataset = dataset.filter(
|
||||
lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
|
||||
)
|
||||
if args.type == "kto":
|
||||
filter_by = "completion"
|
||||
elif args.type == "preference":
|
||||
filter_by = "chosen_input_ids"
|
||||
else:
|
||||
filter_by = "input_ids"
|
||||
dataset = dataset.filter(lambda data: data[filter_by] is not None)
|
||||
|
||||
# Save each jsonl spliced dataset.
|
||||
output_index = "0" * (5 - len(str(index))) + str(index)
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
SAVE_DIR=""
|
||||
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type kto \
|
||||
--data_input_dirs /PATH/TO/KTO/DATASET \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
--max_length 1024
|
|
@ -5,9 +5,10 @@ rm -rf $SAVE_DIR/jsonl
|
|||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type preference \
|
||||
--data_input_dirs "PATH/TO/PREFERENCE/DATA" \
|
||||
--data_input_dirs /PATH/TO/PREFERENCE/DATASET \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
--max_length 1024
|
||||
|
|
|
@ -10,4 +10,5 @@ python prepare_dataset.py --type prompt \
|
|||
--tokenizer_dir "" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
--max_length 1024
|
||||
|
|
|
@ -5,9 +5,10 @@ rm -rf $SAVE_DIR/jsonl
|
|||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type sft \
|
||||
--data_input_dirs "PATH/TO/SFT/DATA" \
|
||||
--data_input_dirs /PATH/TO/SFT/DATASET \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
--max_length 4096
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 3:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 3:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>
|
||||
|
||||
==========
|
|
@ -1,4 +1,4 @@
pandas>=1.4.1
sentencepiece
colossalai
colossalai==0.4.0
prompt_toolkit
@ -1,5 +1 @@
XXX.XX.XXX.XXX # Your master IP
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
XXX.XX.XXX.XXX # Your slave IPs
localhost
@ -0,0 +1,9 @@
{
    "r": 128,
    "embedding_lora_dropout": 0.0,
    "linear_lora_dropout": 0.1,
    "lora_alpha": 32,
    "lora_train_bias": "all",
    "lora_initialization_method": "PiSSA",
    "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
}
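The training scripts further below consume a JSON file shaped like this through LoraConfig.from_file(args.lora_config). As a rough illustration of how such a file can be parsed, here is a minimal sketch; the field names are taken from the example above, but the class below is a stand-in and not the actual coati.models.LoraConfig implementation:

# Illustrative only: a minimal reader for a LoRA config file shaped like the JSON above.
import json
from dataclasses import dataclass, field
from typing import List


@dataclass
class LoraConfigSketch:
    r: int = 128
    lora_alpha: int = 32
    linear_lora_dropout: float = 0.1
    embedding_lora_dropout: float = 0.0
    lora_train_bias: str = "all"
    lora_initialization_method: str = "PiSSA"
    target_modules: List[str] = field(default_factory=list)

    @classmethod
    def from_file(cls, path: str) -> "LoraConfigSketch":
        # Read the JSON config and keep only the keys this sketch knows about.
        with open(path) as f:
            raw = json.load(f)
        known = {k: v for k, v in raw.items() if k in cls.__dataclass_fields__}
        return cls(**known)


if __name__ == "__main__":
    # "lora_config.json" is a hypothetical path for this sketch.
    cfg = LoraConfigSketch.from_file("lora_config.json")
    print(cfg.r, cfg.target_modules)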
@ -6,7 +6,7 @@ from contextlib import nullcontext
|
|||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import convert_to_lora_module, disable_dropout
|
||||
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import DPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
@ -23,8 +23,11 @@ logger = get_dist_logger()
|
|||
|
||||
|
||||
def train(args):
|
||||
lora_config = None
|
||||
if args.lora_config is not None:
|
||||
lora_config = LoraConfig.from_file(args.lora_config)
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
@ -115,8 +118,8 @@ def train(args):
|
|||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if args.enable_reference_model:
|
||||
|
||||
if not args.disable_reference_model:
|
||||
if args.use_flash_attn:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
|
@ -125,18 +128,20 @@ def train(args):
|
|||
)
|
||||
else:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(ref_model)
|
||||
else:
|
||||
ref_model = None
|
||||
if args.lora_config is not None:
|
||||
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name or "gate" in name:
|
||||
module = module.to(torch.float32)
|
||||
disable_dropout(model)
|
||||
disable_dropout(ref_model)
|
||||
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.grad_checkpoint:
|
||||
# Note: for some models, LoRA may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
|
@ -178,6 +183,21 @@ def train(args):
|
|||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
eval_dataloader = None
|
||||
if args.eval_dataset:
|
||||
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
|
||||
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
eval_dataloader = plugin.prepare_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=eval_data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
logger.warning("No evaluation dataset is provided, skip evaluation")
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
|
@ -255,25 +275,28 @@ def train(args):
|
|||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
beta=args.beta,
|
||||
gamma=args.gamma,
|
||||
length_normalization=args.length_normalization,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=None,
|
||||
eval_preference_dataloader=eval_dataloader,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
if lora_config is not None and lora_config.r > 0:
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
|
||||
if args.save_dir is not None:
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
@ -296,6 +319,10 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
|
||||
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
|
||||
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
|
||||
parser.add_argument("--length_normalization", default=False, action="store_true")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
|
@ -304,33 +331,39 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument("--eval_dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default=None)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--enable_reference_model", type=bool, default=True)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
||||
# foolproof hyperparameter setup
|
||||
if args.loss_type == "simpo_loss":
|
||||
args.length_normalization = True
|
||||
args.gamma = args.gamma if args.gamma > 0 else 1.4
|
||||
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
||||
|
|
|
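One detail of the DPO entry point above that is easy to miss: when --loss_type simpo_loss is selected, the script forces length normalization on and substitutes gamma = 1.4 whenever a non-positive gamma was passed. A standalone sketch of that fallback, using a throwaway namespace in place of the real argparse result:

# Sketch of the SimPO default handling shown in the training script above.
from types import SimpleNamespace


def apply_simpo_defaults(args):
    # Length normalization is required for SimPO; gamma falls back to 1.4 if unset.
    if args.loss_type == "simpo_loss":
        args.length_normalization = True
        args.gamma = args.gamma if args.gamma > 0 else 1.4
    return args


args = apply_simpo_defaults(
    SimpleNamespace(loss_type="simpo_loss", length_normalization=False, gamma=0.0)
)
print(args.length_normalization, args.gamma)  # True 1.4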
@ -13,50 +13,52 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
# export CUDA_VISIBLE_DEVICES=6
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
PROJECT_NAME="dpo"
|
||||
PROJECT_NAME="DPO"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PARENT_LOG_DIR="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
YOUR/DATA/DIR/arrow/part-00000
|
||||
YOUR/DATA/DIR/arrow/part-00001
|
||||
YOUR/DATA/DIR/arrow/part-00002
|
||||
YOUR/DATA/DIR/arrow/part-00003
|
||||
YOUR/DATA/DIR/arrow/part-00004
|
||||
YOUR/DATA/DIR/arrow/part-00005
|
||||
YOUR/DATA/DIR/arrow/part-00006
|
||||
YOUR/DATA/DIR/arrow/part-00007
|
||||
YOUR/DATA/DIR/arrow/part-00008
|
||||
YOUR/DATA/DIR/arrow/part-00009
|
||||
/Your/Preference/Data/arrow/part-00000
|
||||
/Your/Preference/Data/arrow/part-00001
|
||||
/Your/Preference/Data/arrow/part-00002
|
||||
/Your/Preference/Data/arrow/part-00003
|
||||
/Your/Preference/Data/arrow/part-00004
|
||||
/Your/Preference/Data/arrow/part-00005
|
||||
/Your/Preference/Data/arrow/part-00006
|
||||
/Your/Preference/Data/arrow/part-00007
|
||||
/Your/Preference/Data/arrow/part-00008
|
||||
/Your/Preference/Data/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_dpo.py \
|
||||
colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--checkpoint_path $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--log_dir $LOG_DIR \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--batch_size 2 \
|
||||
--accumulation_steps 2 \
|
||||
--batch_size 16 \
|
||||
--lr 1e-6 \
|
||||
--beta 0.1 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 4096 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 100 \
|
||||
--warmup_steps 60 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
||||
|
|
|
@ -0,0 +1,376 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import KTOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin  # TorchDDPPlugin is required by the "ddp" branch below
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
lora_config = None
|
||||
if args.lora_config is not None:
|
||||
lora_config = LoraConfig.from_file(args.lora_config)
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
ref_booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
|
||||
if args.use_flash_attn:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
else:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
if args.lora_config is not None:
|
||||
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name or "gate" in name:
|
||||
module = module.to(torch.float32)
|
||||
disable_dropout(ref_model)
|
||||
disable_dropout(model)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note: for some models, LoRA may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
num_desirable = 0
|
||||
num_undesirable = 0
|
||||
for i in range(len(train_dataset)):
|
||||
if train_dataset[i]["label"]:
|
||||
num_desirable += 1
|
||||
else:
|
||||
num_undesirable += 1
|
||||
logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}")
|
||||
|
||||
# Check whether the user-specified weights fall within the theoretical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
|
||||
actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
|
||||
if actual_ratio < 1 or actual_ratio > 4 / 3:
|
||||
if not args.auto_weight:
|
||||
raise AssertionError(
|
||||
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase/decrease desirable weight or decrease/increase undesirable weight."
|
||||
)
|
||||
else:
|
||||
args.desirable_weight = args.desirable_weight / actual_ratio
|
||||
coordinator.print_on_master(
|
||||
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}"
|
||||
)
|
||||
|
||||
data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
eval_dataloader = None
|
||||
if args.eval_dataset:
|
||||
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
|
||||
eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
eval_dataloader = plugin.prepare_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=eval_data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
logger.warning("No evaluation dataset is provided, skip evaluation")
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
if ref_model is not None:
|
||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = KTOTrainer(
|
||||
actor=model,
|
||||
ref_model=ref_model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
beta=args.beta,
|
||||
desirable_weight=args.desirable_weight,
|
||||
undesirable_weight=args.undesirable_weight,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=eval_dataloader,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if lora_config is not None and lora_config.r > 0:
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
if args.save_dir is not None:
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
|
||||
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
|
||||
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument("--eval_dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default=None)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--auto_weight", default=False, action="store_true")
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
|
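To make the desirable/undesirable weighting check in train_kto.py above concrete: the script requires the ratio (desirable_weight * num_desirable) / (undesirable_weight * num_undesirable) to lie within [1, 4/3] (Eq. (8) of https://arxiv.org/abs/2402.01306), and with --auto_weight it divides the desirable weight by the observed ratio to pull it back into range. A small numerical sketch with made-up dataset counts:

# Illustrative numbers only; the real check lives in train_kto.py above.
num_desirable, num_undesirable = 9000, 3000   # assumed dataset statistics
desirable_weight, undesirable_weight = 1.0, 1.0

ratio = (desirable_weight * num_desirable) / (undesirable_weight * num_undesirable)
print(ratio)  # 3.0 -> outside the recommended [1, 4/3] band

# What --auto_weight does in the script above: divide the desirable weight by the ratio.
desirable_weight = desirable_weight / ratio
ratio = (desirable_weight * num_desirable) / (undesirable_weight * num_undesirable)
print(desirable_weight, ratio)  # 0.333..., 1.0 -> back at the lower bound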
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
PROJECT_NAME="kto"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PARENT_LOG_DIR="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
/Your/KTO/Data/arrow/part-00000
|
||||
/Your/KTO/Data/arrow/part-00001
|
||||
/Your/KTO/Data/arrow/part-00002
|
||||
/Your/KTO/Data/arrow/part-00003
|
||||
/Your/KTO/Data/arrow/part-00004
|
||||
/Your/KTO/Data/arrow/part-00005
|
||||
/Your/KTO/Data/arrow/part-00006
|
||||
/Your/KTO/Data/arrow/part-00007
|
||||
/Your/KTO/Data/arrow/part-00008
|
||||
/Your/KTO/Data/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
|
||||
|
||||
colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--log_dir $LOG_DIR \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 8 \
|
||||
--auto_weight \
|
||||
--lr 1e-5 \
|
||||
--beta 0.1 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 1024 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 60 \
|
||||
--grad_checkpoint
|
|
@ -0,0 +1,341 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import ORPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin  # TorchDDPPlugin is required by the "ddp" branch below
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
lora_config = None
|
||||
if args.lora_config is not None:
|
||||
lora_config = LoraConfig.from_file(args.lora_config)
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
if args.lora_config is not None:
|
||||
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name or "gate" in name:
|
||||
module = module.to(torch.float32)
|
||||
disable_dropout(model)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note: for some models, LoRA may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if args.eval_dataset:
|
||||
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
|
||||
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
eval_dataloader = plugin.prepare_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=eval_data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
logger.warning("No evaluation dataset is provided, skip evaluation")
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = ORPOTrainer(
|
||||
actor=model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
lam=args.lam,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=eval_dataloader,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if lora_config is not None and lora_config.r > 0:
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
if args.save_dir is not None:
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument("--eval_dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default=None)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
|
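For reference, the warmup default used by the preference-tuning scripts above (applied when --warmup_steps is not given) amounts to roughly 2.5% of the total number of optimizer update steps. A quick numerical sketch with assumed sizes:

# Hypothetical numbers to illustrate the warmup default shown above:
# warmup_steps = int(max_epochs * 0.025 * (len(train_dataloader) // accumulation_steps))
max_epochs = 3
num_batches = 4000           # len(train_dataloader), assumed for illustration
accumulation_steps = 8
update_steps_per_epoch = num_batches // accumulation_steps   # 500
warmup_steps = int(max_epochs * 0.025 * update_steps_per_epoch)
print(warmup_steps)  # 37, i.e. ~2.5% of the 1500 total update steps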
@ -0,0 +1,64 @@
|
|||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
PROJECT_NAME="ORPO"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PARENT_LOG_DIR="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
/Your/Preference/Data/arrow/part-00000
|
||||
/Your/Preference/Data/arrow/part-00001
|
||||
/Your/Preference/Data/arrow/part-00002
|
||||
/Your/Preference/Data/arrow/part-00003
|
||||
/Your/Preference/Data/arrow/part-00004
|
||||
/Your/Preference/Data/arrow/part-00005
|
||||
/Your/Preference/Data/arrow/part-00006
|
||||
/Your/Preference/Data/arrow/part-00007
|
||||
/Your/Preference/Data/arrow/part-00008
|
||||
/Your/Preference/Data/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--log_dir $LOG_DIR \
|
||||
--max_epochs 3 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 16 \
|
||||
--lr 8e-6 \
|
||||
--lam 0.5 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 1024 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 60 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
|
@ -13,7 +13,7 @@ from coati.dataset import (
|
|||
load_tokenized_dataset,
|
||||
setup_conversation_template,
|
||||
)
|
||||
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
|
||||
from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
|
||||
from coati.trainer import PPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
@ -31,8 +31,11 @@ logger = get_dist_logger()
|
|||
|
||||
|
||||
def train(args):
|
||||
lora_config = None
|
||||
if args.lora_config is not None:
|
||||
lora_config = LoraConfig.from_file(args.lora_config)
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
@ -81,20 +84,26 @@ def train(args):
|
|||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
|
||||
reward_model = RewardModel(args.rm_pretrain)
|
||||
critic = Critic(args.rm_pretrain)
|
||||
|
||||
if args.lora_config is not None:
|
||||
actor = convert_to_lora_module(actor, lora_config=lora_config)
|
||||
critic = convert_to_lora_module(critic, lora_config=lora_config)
|
||||
for name, module in actor.named_modules():
|
||||
if "norm" in name or "gate" in name:
|
||||
module = module.to(torch.float32)
|
||||
for name, module in critic.named_modules():
|
||||
if "norm" in name or "gate" in name:
|
||||
module = module.to(torch.float32)
|
||||
lora_manager.able_to_merge = False
|
||||
|
||||
# Disable dropout
|
||||
disable_dropout(actor)
|
||||
disable_dropout(critic)
|
||||
|
||||
if args.lora_rank > 0:
|
||||
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
actor.gradient_checkpointing_enable()
|
||||
critic.model.gradient_checkpointing_enable()
|
||||
if args.grad_checkpoint:
|
||||
actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
|
@ -421,11 +430,9 @@ def train(args):
|
|||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
if lora_config is not None and lora_config.r > 0:
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
lora_manager.able_to_merge = True
|
||||
actor.eval()
|
||||
critic.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
|
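
The merge path no longer flips LORA_MANAGER.merge_weights; instead lora_manager.able_to_merge gates when eval() is allowed to fold the adapters back into the base weights. A hedged sketch of the pattern used above:

from coati.models import lora_manager  # import added by this change

lora_manager.able_to_merge = False  # training: keep LoRA adapters separate from the base weights
# ... trainer.fit(...) runs here ...
lora_manager.able_to_merge = True   # training done: eval() may now merge the LoRA weights
# actor.eval(); critic.eval()       # per the NOTE above, switching to eval performs the merge
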
@ -484,11 +491,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--train_batch_size", type=int, default=16)
|
||||
parser.add_argument("--experience_batch_size", type=int, default=16)
|
||||
parser.add_argument("--ptx_batch_size", type=int, default=4)
|
||||
parser.add_argument("--lora_train_bias", type=str, default="none")
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=9e-6)
|
||||
parser.add_argument("--critic_lr", type=float, default=9e-6)
|
||||
parser.add_argument("--kl_coef", type=float, default=0.1)
|
||||
|
|
|
@ -15,10 +15,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
|
||||
PROJECT_NAME="ppo"
|
||||
PROJECT_NAME="PPO"
|
||||
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
@ -54,7 +53,7 @@ declare -a ptx_dataset=(
|
|||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
|
|
|
@ -7,7 +7,7 @@ from contextlib import nullcontext
|
|||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
|
||||
from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module
|
||||
from coati.trainer import RewardModelTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoTokenizer
|
||||
|
@ -16,14 +16,20 @@ import colossalai
|
|||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
lora_config = None
|
||||
if args.lora_config is not None:
|
||||
lora_config = LoraConfig.from_file(args.lora_config)
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
@ -55,9 +61,11 @@ def train(args):
|
|||
args.pretrain,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if lora_config is not None:
|
||||
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name or "gate" in name:
|
||||
module = module.to(torch.float32)
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
|
@ -119,11 +127,9 @@ def train(args):
|
|||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer
|
||||
if args.grad_checkpoint:
|
||||
model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
|
@ -173,6 +179,22 @@ def train(args):
|
|||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if args.eval_dataset:
|
||||
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
|
||||
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
eval_dataloader = plugin.prepare_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=eval_data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
logger.warning("No evaluation dataset is provided, skip evaluation")
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
math.ceil(args.max_epochs * num_update_steps_per_epoch)
|
||||
|
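
For context, the update-step arithmetic above divides the dataloader length by the accumulation steps before scaling by the epoch count. A quick worked example with illustrative numbers:

# Illustrative numbers only: 1000 batches with accumulation_steps=8 gives 125 optimizer
# steps per epoch, so 3 epochs correspond to ceil(3 * 125) = 375 update steps in total.
import math

num_batches, accumulation_steps, max_epochs = 1000, 8, 3
num_update_steps_per_epoch = num_batches // accumulation_steps
total_steps = math.ceil(max_epochs * num_update_steps_per_epoch)
print(num_update_steps_per_epoch, total_steps)  # 125 375
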
||||
|
@ -253,21 +275,21 @@ def train(args):
|
|||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=None,
|
||||
eval_preference_dataloader=eval_dataloader,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
if lora_config is not None and lora_config.r > 0:
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
|
||||
if args.save_dir is not None:
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
@ -297,33 +319,28 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument("--eval_dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default=None)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
||||
|
|
|
@ -15,10 +15,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
|
||||
PROJECT_NAME="rm"
|
||||
PROJECT_NAME="RM"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PARENT_LOG_DIR="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
|
@ -38,17 +38,18 @@ declare -a dataset=(
|
|||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--log_dir $LOG_DIR \
|
||||
--max_epochs 3 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 8 \
|
||||
|
|
|
@ -7,7 +7,7 @@ from contextlib import nullcontext
|
|||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.models import LoraConfig, convert_to_lora_module
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
@ -24,8 +24,11 @@ logger = get_dist_logger()
|
|||
|
||||
|
||||
def train(args):
|
||||
lora_config = None
|
||||
if args.lora_config is not None:
|
||||
lora_config = LoraConfig.from_file(args.lora_config)
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
@ -53,15 +56,19 @@ def train(args):
|
|||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if lora_config is not None:
|
||||
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name or "gate" in name:
|
||||
module = module.to(torch.float32)
|
||||
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for debugging purposes
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)  # unused-parameter detection conflicts with gradient checkpointing, so turn it off when checkpointing is enabled
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
|
@ -114,6 +121,15 @@ def train(args):
|
|||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
|
@ -122,12 +138,10 @@ def train(args):
|
|||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
# lora layers are not supported by gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
@ -151,15 +165,6 @@ def train(args):
|
|||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
|
@ -175,6 +180,23 @@ def train(args):
|
|||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if args.eval_dataset:
|
||||
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
|
||||
eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
|
||||
|
||||
eval_dataloader = plugin.prepare_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=eval_data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
logger.warning("No evaluation dataset is provided, skip evaluation")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
@ -202,6 +224,7 @@ def train(args):
|
|||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
|
@ -257,22 +280,21 @@ def train(args):
|
|||
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=None,
|
||||
eval_dataloader=eval_dataloader,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
if lora_config is not None and lora_config.r > 0:
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
|
||||
# booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
|
||||
if args.save_path is not None:
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}"
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
@ -302,32 +324,27 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument("--eval_dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--save_path", type=str, default="output")
|
||||
parser.add_argument("--save_path", type=str, default=None)
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--max_len", type=int, default=512)
|
||||
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
||||
|
|
|
@ -13,13 +13,11 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
|
||||
# export CUDA_VISIBLE_DEVICES=4,5,6
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
PROJECT_NAME="sft"
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
PROJECT_NAME="SFT"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PARENT_LOG_DIR="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
declare -a dataset=(
|
||||
|
@ -38,28 +36,25 @@ declare -a dataset=(
|
|||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
|
||||
|
||||
echo $(which colossalai)
|
||||
echo $(which python)
|
||||
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
|
||||
colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
|
||||
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--save_interval 4000 \
|
||||
--save_interval 2000 \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--lora_rank 0 \
|
||||
--plugin 3d \
|
||||
--tp 2 \
|
||||
--pp 1 \
|
||||
--zero_stage 0 \
|
||||
--batch_size 2 \
|
||||
--max_epochs 3 \
|
||||
--plugin zero2 \
|
||||
--batch_size 8 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 1 \
|
||||
--lr 5e-5 \
|
||||
--max_len 400 \
|
||||
--max_len 4096 \
|
||||
--use_flash_attn \
|
||||
--grad_checkpoint \
|
||||
--use_wandb \
|
||||
--use_flash_attn
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--log_dir $LOG_DIR \
|
||||
|
|
|
@ -1,9 +1,9 @@
transformers>=4.36.2
transformers==4.39.3
tqdm
datasets==2.14.7
loralib
colossalai>=0.3.7
torch>=1.12.1
colossalai==0.4.0
torch>=2.1.0
langchain
tokenizers
fastapi
|
||||
|
|
|
@ -4,7 +4,7 @@ import os
|
|||
|
||||
sft_seed = {
|
||||
"messages": [
|
||||
{"from": "human", "content": "Give three tips for staying healthy."},
|
||||
{"from": "user", "content": "Give three tips for staying healthy."},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.",
|
||||
|
@ -13,7 +13,7 @@ sft_seed = {
|
|||
}
|
||||
prompt_seed = {
|
||||
"messages": [
|
||||
{"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."},
|
||||
{"from": "user", "content": "Describe the impacts of climate change on communities living in coastal areas."},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.",
|
||||
|
@ -22,21 +22,34 @@ prompt_seed = {
|
|||
}
|
||||
preference_seed = {
|
||||
"context": [
|
||||
{"from": "human", "content": "What kind of noises did dinosaurs make?"},
|
||||
{"from": "user", "content": "What kind of noises did dinosaurs make?"},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be",
|
||||
},
|
||||
{"from": "human", "content": "yes they did"},
|
||||
{"from": "user", "content": "yes they did"},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.",
|
||||
},
|
||||
{"from": "human", "content": "you cant read"},
|
||||
{"from": "user", "content": "you cant read"},
|
||||
],
|
||||
"chosen": [{"from": "assistant", "content": "You can read?"}],
|
||||
"rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
|
||||
}
|
||||
kto_seed = {
|
||||
"prompt": [
|
||||
{"from": "user", "content": "What are some praise words in english?"},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ...",
|
||||
},
|
||||
{"from": "user", "content": "What's your favorite one?"},
|
||||
],
|
||||
"completion": {"from": "assistant", "content": "Impressive."},
|
||||
"label": True,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -61,12 +74,21 @@ if __name__ == "__main__":
|
|||
seed = prompt_seed
|
||||
elif args.data_type == "preference":
|
||||
seed = preference_seed
|
||||
elif args.data_type == "kto":
|
||||
seed = kto_seed
|
||||
else:
|
||||
raise ValueError(f"Unknown data type {args.data_type}")
|
||||
|
||||
line = json.dumps(seed, ensure_ascii=False) + "\n"
|
||||
for idx in [1, 2, 3]:
|
||||
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
|
||||
for i in range(1000):
|
||||
if args.data_type != "kto":
|
||||
line = json.dumps(seed, ensure_ascii=False) + "\n"
|
||||
for idx in [1, 2, 3]:
|
||||
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
|
||||
for i in range(1000):
|
||||
f.write(line)
|
||||
f.write(line)
|
||||
f.write(line)
|
||||
else:
|
||||
for idx in [1, 2, 3]:
|
||||
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
|
||||
for i in range(1000):
|
||||
seed["label"] = not seed["label"]
|
||||
line = json.dumps(seed, ensure_ascii=False) + "\n"
|
||||
f.write(line)
|
||||
|
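
The new KTO branch above differs from the other data types only in that it re-serializes the seed on every row and flips the boolean label, so each dummy file ends up with an even mix of desirable and undesirable completions. A standalone sketch of that loop (output path is illustrative):

import json

kto_seed = {
    "prompt": [{"from": "user", "content": "What are some praise words in english?"}],
    "completion": {"from": "assistant", "content": "Impressive."},
    "label": True,
}

# Illustrative path; the real script writes {idx}.jsonl under args.data_dir.
with open("/tmp/kto_dummy.jsonl", "w", encoding="utf8") as f:
    for i in range(10):
        kto_seed["label"] = not kto_seed["label"]  # alternate True/False so both labels appear
        f.write(json.dumps(kto_seed, ensure_ascii=False) + "\n")
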
|
|
@ -1 +1 @@
|
|||
{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
|
||||
{"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
{"prompt": [{"from": "user", "content": "What are some praise words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "impressive."},"label": true}
|
|
@ -1 +1 @@
|
|||
{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
|
||||
{"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
|
||||
|
|
|
@ -71,6 +71,8 @@ get_data_input_dirs() {
echo "$PROMPT_DATASET"
elif [[ $data_type == "preference" ]]; then
echo "$PREFERENCE_DATASET"
elif [[ $data_type == "kto" ]]; then
echo "$KTO_DATASET"
else
echo "Unknown data type $data_type"
exit 1
|
||||
|
@ -121,6 +123,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
|||
--data_dir $(get_data_input_dirs prompt) \
|
||||
--data_type "prompt"
|
||||
|
||||
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs kto) \
|
||||
--data_type "kto"
|
||||
|
||||
echo "[Test]: testing prepare_preference_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
|
@ -258,3 +264,50 @@ for model in ${MODELS[@]}; do
|
|||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing prepare_kto_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
SKIPPED_TESTS=(
|
||||
)
|
||||
|
||||
# test prepare_kto_dataset
|
||||
for model in ${MODELS[@]}; do
|
||||
data_type="kto"
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
|
||||
echo "[Test]: Skipped $model-$data_type"
|
||||
continue
|
||||
fi
|
||||
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
|
||||
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
|
||||
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
|
||||
data_input_dirs=$(get_data_input_dirs $data_type)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
rm -rf $cache_dir
|
||||
rm -rf $jsonl_dir
|
||||
rm -rf $arrow_dir
|
||||
echo "[Test]: $model-$data_type, attempt $i"
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
|
||||
--type kto \
|
||||
--data_input_dirs $data_input_dirs \
|
||||
--conversation_template_config $conversation_template \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--data_cache_dir $cache_dir \
|
||||
--data_jsonl_output_dir $jsonl_dir \
|
||||
--data_arrow_output_dir $arrow_dir \
|
||||
--max_length 400 \
|
||||
--num_samples_per_datafile 100 \
|
||||
--num_spliced_dataset_bins 1
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$data_type"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
|
||||
|
@ -38,7 +39,7 @@ def test_overfit():
|
|||
# Build and convert model
|
||||
model = SimpleNN(input_size, hidden_size, num_classes)
|
||||
weight_to_compare = model.fc1.weight.detach().clone()
|
||||
model = convert_to_lora_module(model, lora_rank=30)
|
||||
model = convert_to_lora_module(model, lora_config=LoraConfig(r=32))
|
||||
|
||||
# Loss and optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
@ -50,7 +51,6 @@ def test_overfit():
|
|||
# Forward pass
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
print(loss)
|
||||
# Backward and optimize
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
@ -65,5 +65,50 @@ def test_overfit():
|
|||
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
||||
|
||||
|
||||
def test_lora_linear_accuracy():
|
||||
|
||||
weight = torch.randn(10, 5)
|
||||
linear = nn.Linear(5, 10)
|
||||
linear.weight.data = weight
|
||||
x = torch.randn(10, 5)
|
||||
out_linear = linear(x)
|
||||
|
||||
# lora linear Pissa
|
||||
linear.weight.data = weight
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA")
|
||||
out_lora = lora_linear(x)
|
||||
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
# lora linear
|
||||
linear.weight.data = weight
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
|
||||
out_lora = lora_linear(x)
|
||||
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
|
||||
def test_lora_embedding_accuracy():
|
||||
weight = torch.randn(10, 5)
|
||||
embedding = nn.Embedding(10, 5)
|
||||
embedding.weight.data = weight
|
||||
x = torch.randint(0, 10, (10,))
|
||||
out_embedding = embedding(x)
|
||||
|
||||
# lora embedding Pissa
|
||||
embedding.weight.data = weight
|
||||
lora_embedding = LoraEmbedding(
|
||||
embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5
|
||||
)
|
||||
out_lora = lora_embedding(x)
|
||||
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
# lora embedding
|
||||
embedding.weight.data = weight
|
||||
lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5)
|
||||
out_lora = lora_embedding(x)
|
||||
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_overfit()
|
||||
test_lora_linear_accuracy()
|
||||
test_lora_embedding_accuracy()
|
||||
|
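
The two accuracy tests above assert that wrapping a layer in LoraLinear or LoraEmbedding is a no-op at initialization, both for the default init and for the PiSSA-style init (which moves the principal singular components of the weight into the adapter and keeps the residual in the base, so the wrapped layer still reproduces the original output). A minimal torch sketch of why the default case holds, assuming the standard LoRA zero-initialization of the B matrix:

# Hedged sketch (not the coati implementation): with B initialized to zero, the
# low-rank update B @ A contributes nothing, so base and LoRA outputs match.
import torch

torch.manual_seed(0)
W = torch.randn(10, 5)   # frozen base weight
A = torch.randn(2, 5)    # lora_A, rank 2
B = torch.zeros(10, 2)   # lora_B starts at zero
x = torch.randn(4, 5)

base_out = x @ W.t()
lora_out = x @ W.t() + x @ (B @ A).t()  # identical while B == 0
assert torch.allclose(base_out, lora_out)
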
|
|
@ -94,7 +94,7 @@ done
|
|||
|
||||
# Test DPO/PPO data Preparation
|
||||
for model in ${MODELS[@]}; do
|
||||
echo "Testing DPO/PPO data templating for $model"
|
||||
echo "Testing DPO/RM data templating for $model"
|
||||
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
|
@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do
|
|||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the DPO data templating for $model"
|
||||
echo "[Test]: Failed in the DPO/RM data templating for $model"
|
||||
exit 1
|
||||
fi
|
||||
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
|
||||
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the DPO data templating test for $model"
|
||||
echo "[Test]: Failed in the DPO/RM data templating test for $model"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
# Test KTO data Preparation
|
||||
for model in ${MODELS[@]}; do
|
||||
echo "Testing KTO data templating for $model"
|
||||
SAVE_DIR=$DATA_SAVE_PATH/kto/$model
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
pretrain=$(get_pretrain $model)
|
||||
conversation_template_config=$(get_conversation_template_config $model)
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \
|
||||
--tokenizer_dir $pretrain \
|
||||
--conversation_template_config $conversation_template_config \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the KTO data templating for $model"
|
||||
exit 1
|
||||
fi
|
||||
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \
|
||||
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the KTO data templating test for $model"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
|
|
@ -30,9 +30,10 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
|||
MODELS_DIR=$TEMP_DIR/models_config
|
||||
# Skip those tests due to CI tests timeout
|
||||
MODELS=('llama')
|
||||
ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy
|
||||
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
|
||||
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
|
||||
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
|
||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
|
@ -112,6 +113,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
sp='1'
|
||||
sp_mode='split_gather'
|
||||
enable_sequence_parallelism=''
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
|
@ -173,9 +179,10 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
--lora_rank $lora_rank \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
|
@ -192,8 +199,8 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
@ -229,6 +236,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
|
@ -248,9 +260,10 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_dir $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
--lora_rank $lora_rank \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
|
@ -262,8 +275,8 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
@ -306,6 +319,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
bs='4'
|
||||
ebs='8'
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='16'
|
||||
|
@ -342,7 +360,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--ptx_batch_size 1 \
|
||||
--ptx_coef 0.2 \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--lora_rank $lora_rank \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--num_episodes 5 \
|
||||
--num_collect_steps 1 \
|
||||
|
@ -361,8 +379,8 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
# --use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
@ -402,6 +420,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini don't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
|
@ -423,9 +446,10 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_dir $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
--lora_rank $lora_rank \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
|
@ -437,8 +461,176 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing ORPO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini don't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto doesn't support generation
|
||||
# (need to calculate ref_model logits through forwarding in inference mode)
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_dir $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing KTO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini don't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto doesn't support generation
|
||||
# (need to calculate ref_model logits through forwarding in inference mode)
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_dir $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
--auto_weight \
|
||||
--desirable_weight 1.2 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
|
|
@ -62,3 +62,11 @@ if __name__ == "__main__":
|
|||
assert any(
|
||||
[rejected_lable in s for s in to_verify_lable_rejected]
|
||||
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
|
||||
elif args.data_type == "kto":
|
||||
sample = data[0]
|
||||
to_verify_data = to_verify_data[0]
|
||||
for line in sample["prompt"]:
|
||||
assert line["content"] in to_verify_data["input_id_decode"]
|
||||
assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
|
||||
assert sample["completion"]["content"] in to_verify_data["completion_decode"]
|
||||
assert sample["label"] == to_verify_data["label"]
|
||||
|
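
The KTO branch added above checks that every prompt turn and the completion survive tokenization and that the desirability label round-trips. A self-contained sketch of those assertions; the decoded strings here are mocked placeholders, not output of the real chat template:

# Hedged sketch of the KTO verification logic; field names input_id_decode,
# completion_decode and label match the code above, the string contents are mocked.
sample = {
    "prompt": [{"from": "user", "content": "What are some praise words in english?"}],
    "completion": {"from": "assistant", "content": "Impressive."},
    "label": True,
}
to_verify_data = {
    "input_id_decode": "<user> What are some praise words in english? <assistant> Impressive.",
    "completion_decode": "Impressive.",
    "label": True,
}
for line in sample["prompt"]:
    assert line["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["completion_decode"]
assert sample["label"] == to_verify_data["label"]
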
|
|
@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
|
||||
files = glob.glob(os.path.join(path, "*.jsonl"))
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
from abc import abstractstaticmethod
|
||||
|
||||
from colossal_eval.utils import jdump
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
|
||||
class BaseDataset:
|
||||
|
@ -12,13 +15,24 @@ class BaseDataset:
|
|||
logger: Logger for the dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
|
||||
self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
|
||||
def __init__(self, path, logger, *args, **kwargs):
|
||||
self.dataset = self.load(path, logger, *args, **kwargs)
|
||||
|
||||
def save(self, save_path):
|
||||
"""Save the converted dataset"""
|
||||
jdump(self.dataset, save_path)
|
||||
|
||||
@abstractstaticmethod
|
||||
def load(path, logger):
|
||||
def load(path, logger: DistributedLogger, *args, **kwargs):
|
||||
"""Load the original dataset and convert it into the inference dataset"""
|
||||
|
||||
|
||||
class DistributedDataset(Dataset):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
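
The base-class refactor above loosens load to *args/**kwargs and introduces DistributedDataset, a thin list wrapper that lets converted samples flow through a torch DataLoader (matching the inference(data_loader, ...) signature changed later in this diff). A self-contained usage sketch; the wrapper is re-declared here only because its exact import path is not shown in the diff:

from torch.utils.data import DataLoader, Dataset


class DistributedDataset(Dataset):  # mirrors the class introduced above
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


samples = [{"instruction": "Q1", "output": ""}, {"instruction": "Q2", "output": ""}]
loader = DataLoader(DistributedDataset(samples), batch_size=2, collate_fn=lambda batch: batch)
for batch in loader:
    print(len(batch), batch[0]["instruction"])  # each batch is a plain list of sample dicts
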
|
|
@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
|
|
|
@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
|
|
|
@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
data = jload(path)
|
||||
data_per_category = get_data_per_category(data)
|
||||
|
|
|
@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl")
|
||||
data_list = []
|
||||
|
|
|
@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
|
||||
files = os.listdir(os.path.join(path, "data", category))
|
||||
|
|
|
@@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset):
"""

@staticmethod
def load(path: str, logger: DistributedLogger) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}

files = os.listdir(path)

@@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset):
"""

@staticmethod
def load(
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))

@@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset):
This dataset class will convert the original dataset into the inference dataset.
"""

def __init__(self, path, logger, few_shot):
def __init__(self, path, logger: DistributedLogger, *args, **kwargs):
self.multiturn = True
self.dataset = self.load(path, logger, few_shot)
self.dataset = self.load(path, logger, *args, **kwargs)

@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": defaultdict(dict)}

file_path = os.path.join(path, "question.jsonl")
@@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset):
"""

@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files:

@@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset):
"""

@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files:
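Taken together, these hunks converge every dataset class on a single `load(path, logger, *args, **kwargs)` entry point, so the runner can call all loaders uniformly. A minimal sketch of the pattern follows; the `MyDataset` class and the flat-directory file layout are hypothetical and used purely for illustration.

```python
# Hypothetical loader following the unified signature: extra flags such as few_shot,
# forward_only or load_train arrive via *args/**kwargs and can simply be ignored.
import os
from typing import Dict


class MyDataset:
    @staticmethod
    def load(path: str, logger, *args, **kwargs) -> Dict[str, Dict]:
        dataset = {"test": {}}
        for file_name in os.listdir(path):  # assumed: one JSON file per category
            category = os.path.splitext(file_name)[0]
            dataset["test"][category] = {"data": []}
        return dataset
```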
@@ -1,11 +1,11 @@
import copy
import math
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
@@ -130,7 +130,7 @@ class HuggingFaceModel(BaseModel):
if shard_config is not None:
self.model = AutoModel.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
self.model, _ = shard_former.optimize(self.model)
self.model.to(get_current_device())

if peft_path is not None:

@@ -325,7 +325,7 @@ class HuggingFaceModel(BaseModel):

return input_ids_list, labels_list, None

def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
"""
Infer the given data.
This function will call self.generate() to get model outputs and also self.model() to get logits.
@@ -359,26 +359,23 @@ class HuggingFaceModel(BaseModel):

self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}

turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
turn_desc = "" if turn == 0 else f"-turn{turn}"

bar = tqdm(
range(math.ceil(len(data) / self.batch_size)),
desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
range(len(data_loader)),
desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
disable=not is_rank_0(),
)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

answers = copy.deepcopy(data)
for i in range(0, len(data), self.batch_size):
batch = data[i : i + self.batch_size]
answers = []

for i, batch in enumerate(data_loader):
batch_prompt, batch_target = get_batch_prompt(
self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
)

if is_rank_0() and debug and i == 0:
self.logger.info(
f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}"
f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
)
self.logger.info("-" * 120)
self.logger.info("An example prompt and prompt with target is:")
@@ -402,7 +399,7 @@ class HuggingFaceModel(BaseModel):
# Otherwise this will violate the single-choice setting.

if calculate_loss:
labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]
labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]

loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
@@ -411,29 +408,30 @@ class HuggingFaceModel(BaseModel):
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
]

for j in range(len(batch_prompt)):
for j in range(len(batch)):
if not pretrain:
if isinstance(answers[i + j]["output"], list):
answers[i + j]["output"].append(batch_decodes[j].strip())
if isinstance(batch[j]["output"], list):
batch[j]["output"].append(batch_decodes[j].strip())
else:
answers[i + j]["output"] = batch_decodes[j].strip()
batch[j]["output"] = batch_decodes[j].strip()

if isinstance(scores, torch.Tensor):
answers[i + j]["logits_over_choices"] = probs[j]
batch[j]["logits_over_choices"] = probs[j]

if calculate_loss:
answers[i + j]["loss_over_choices"] = loss_over_choices[j]
batch[j]["loss_over_choices"] = loss_over_choices[j]

if calculate_loss:
answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()

# loss_sum is specially used for pretrain datasets for calculating per-byte-perplexity.
# However, loss (which is the per-sample loss) suffices for most cases.
answers[i + j]["loss_sum"] = batch_losses[j]
answers[i + j]["token_num"] = batch_target_token_nums[j]
batch[j]["loss_sum"] = batch_losses[j]
batch[j]["token_num"] = batch_target_token_nums[j]

if batch_bytes_nums:
answers[i + j]["byte_num"] = batch_bytes_nums[j]
batch[j]["byte_num"] = batch_bytes_nums[j]
answers.extend(batch)

bar.update()
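The loop above mutates each `batch` in place and appends it to `answers`, which only works because the dataloader hands back plain lists of dicts rather than stacked tensors. A minimal sketch of that contract; the sample contents and field names below are illustrative assumptions.

```python
# Sketch: an identity collate_fn keeps each batch as a list of dicts, so fields such as
# "output" or "loss" can be attached per sample and the whole batch extended into answers.
from torch.utils.data import DataLoader

samples = [{"dataset": "demo", "category": "a", "instruction": f"q{i}", "output": ""} for i in range(6)]
loader = DataLoader(samples, batch_size=2, collate_fn=lambda x: x)

answers = []
for batch in loader:
    for sample in batch:
        sample["output"] = "model answer"  # stand-in for batch_decodes[j].strip()
    answers.extend(batch)

assert len(answers) == len(samples)
```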
@@ -600,7 +598,7 @@ class HuggingFaceCausalLM(HuggingFaceModel):
if shard_config is not None:
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
self.model, _ = shard_former.optimize(self.model)
self.model.to(get_current_device())

if peft_path is not None:
@@ -123,15 +123,13 @@ class Conversation:
}


def get_few_shot_prefix(
conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
) -> str:
def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int) -> str:
"""
Get few shot prefix.

Args:
conv: Conversation template.
few_shot_examples: Few shot examples to generate few shot prompt prefix.
few_shot_data: Few shot examples to generate few shot prompt prefix.
tokenizer: tokenizer used to tokenize data.

Returns:
Few shot prompt prefix.

@@ -157,7 +155,6 @@ def get_batch_prompt(
batch: List[Dict],
few_shot_data: List[str],
tokenizer: Optional[AutoTokenizer],
language: Optional[str],
model_max_length: Optional[int],
) -> Tuple[List[Dict], List[Dict]]:
"""
@@ -167,6 +164,7 @@ def get_batch_prompt(
conv: Conversation template.
batch: Batch data to generate prompt from.
few_shot_data: Few shot data to generate few shot prompt prefix.
tokenizer: tokenizer used to tokenize data.

Returns:
Tuple containing batch prompt and target.
@@ -192,7 +190,7 @@ def get_batch_prompt(
else:
raise Exception("When using few-shot, target answer should be a string.")

few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
few_shot_prefix = get_few_shot_prefix(few_shot_data, tokenizer, max_tokens)

conv.append_message(conv.roles[0], few_shot_prefix + query_text)
conv.append_message(conv.roles[1], None)
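For reference, a token-budgeted few-shot prefix can be built roughly as follows. This is only a sketch of the idea behind the simplified `get_few_shot_prefix(few_shot_data, tokenizer, max_tokens)` signature; the stop-before-exceeding-the-budget truncation policy and the helper name are assumptions, not the project's exact behavior.

```python
# Sketch: concatenate few-shot examples until the tokenized prefix would exceed max_tokens.
from typing import List, Optional


def build_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional["AutoTokenizer"], max_tokens: int) -> str:
    prefix = ""
    for example in few_shot_data:
        candidate = prefix + example + "\n\n"
        if tokenizer is not None and len(tokenizer(candidate)["input_ids"]) > max_tokens:
            break  # keep the prefix within the model's context budget
        prefix = candidate
    return prefix
```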
@@ -5,6 +5,8 @@ from typing import Dict, List

import torch.distributed as dist
from colossal_eval import dataset, models, utils
from colossal_eval.dataset.base import DistributedDataset
from torch.utils.data import DataLoader, DistributedSampler

import colossalai
from colossalai.accelerator import get_accelerator

@@ -13,6 +15,7 @@ from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig

logger = get_dist_logger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def rm_and_merge(

@@ -54,7 +57,8 @@ def rm_and_merge(
)
else:
rank_answers = utils.jload(directory)
answers["data"].extend(rank_answers["data"])
deduplidate_answers = [x for x in rank_answers["data"] if x not in answers["data"]]
answers["data"].extend(deduplidate_answers)
answers["inference_kwargs"] = rank_answers["inference_kwargs"]

for r in range(dp_size):

@@ -65,7 +69,7 @@ def rm_and_merge(
os.remove(directory)
except Exception as e:
print(e)

print(len(answers["data"]))
all_answers[category] = answers

all_answers_with_dataset_class["inference_results"] = all_answers
@@ -108,7 +112,12 @@ def main(args):
tp_rank = coordinates[TP_AXIS]

shard_config = (
ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1)
ShardConfig(
tensor_parallel_process_group=tp_group,
enable_tensor_parallelism=args.tp_size > 1,
parallel_output=False,
enable_all_optimization=True,
)
if args.tp_size > 1
else None
)

@@ -183,6 +192,7 @@ def main(args):
model_name = model_parameter["name"]
model_class = eval(f"models.{model_parameter['model_class']}")
paramerters = model_parameter["parameters"]
batch_size = paramerters["batch_size"]
paramerters.update({"logger": logger})
paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]})
paramerters.update({"shard_config": shard_config})

@@ -192,7 +202,6 @@ def main(args):
raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")

for dataset_name, split_data in inference_data.items():
start = 0
prev_questions = None
for category, category_data in split_data.items():
num_turn = category_data["inference_kwargs"].get("turns", 1)
@@ -201,26 +210,33 @@ def main(args):
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")

answers_to_dump = copy.deepcopy(category_data)
partition_size = len(category_data["data"]) // dp_size
redundant = len(category_data["data"]) % dp_size

# Ensure that the amount of data for inference is as consistent as possible across different processes.
lengths = [partition_size for _ in range(dp_size)]
for j in range(redundant):
lengths[(j + start) % dp_size] += 1

start = (start + redundant) % dp_size

for turn in range(num_turn):
if turn == 0:
questions = category_data["data"][
sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank]
]
dist_dataset = DistributedDataset(category_data["data"])
else:
questions = prev_questions
dist_dataset = DistributedDataset(prev_questions)

sampler = DistributedSampler(
dist_dataset,
num_replicas=pg_mesh.size(DP_AXIS),
rank=pg_mesh.coordinate(DP_AXIS),
shuffle=False,
)
questions_loader = DataLoader(
dist_dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=8,
pin_memory=True,
collate_fn=lambda x: x,
)
category_data["inference_kwargs"]["dataset"] = dataset_name
category_data["inference_kwargs"]["category"] = category

answers_per_rank = model_.inference(
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
data_loader=questions_loader,
inference_kwargs=category_data["inference_kwargs"],
debug=debug_args[dataset_name],
)
prev_questions = answers_per_rank
|
|
@ -2,7 +2,7 @@ import ctypes
|
|||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
|
@ -30,11 +30,15 @@ from colossalai.interface.optimizer import DistributedOptim
|
|||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.d_tensor.api import is_distributed_tensor
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
||||
|
||||
from .pp_plugin_base import PipelinePluginBase
|
||||
|
||||
|
@ -61,6 +65,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
use_ddp: bool,
|
||||
ddp_config: dict,
|
||||
custom_policy: Policy,
|
||||
overlap_allgather: bool = False,
|
||||
) -> None:
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.shard_config = shard_config
|
||||
|
@ -69,6 +74,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
self.sp_group = sp_group
|
||||
self.use_dpp = use_ddp
|
||||
self.require_grad_sync = True
|
||||
self.overlap_allgather = overlap_allgather
|
||||
|
||||
shardformer = ShardFormer(shard_config)
|
||||
if custom_policy is not None:
|
||||
|
@ -106,6 +112,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
module = DDP(module, process_group=dp_group, **ddp_config)
|
||||
|
||||
super().__init__(module)
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
p.__init__(p, requires_grad=True)
|
||||
|
||||
def sync_shared_params(self):
|
||||
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
|
||||
|
@ -197,7 +209,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
with self._wait_all_gather():
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def unwrap(self):
|
||||
module = super().unwrap()
|
||||
|
@ -205,6 +218,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
module = module.module
|
||||
return module
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
for p in self.module.parameters():
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
def _wait_all_gather(self):
|
||||
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
||||
|
||||
|
||||
def get_param_info(optim: Optimizer):
|
||||
# Get a backup of necessary information of parameters for future use, which includes:
|
||||
|
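The `_wait_all_gather` helper above returns either the op-hook context or a `nullcontext`, so the forward pass only pays for the hook when `overlap_allgather` is enabled. A minimal sketch of that conditional-context idiom follows; the hook here is a stand-in, not ColossalAI's `ZeroOpHook`.

```python
# Sketch of the "hook context or nullcontext" idiom used around the wrapped forward pass.
from contextlib import contextmanager, nullcontext


@contextmanager
def fake_op_hook():
    # Stand-in for ColoParamOpHookManager.use_hooks(ZeroOpHook()); a real hook would
    # wait on a parameter's pending all-gather before the parameter is used in forward.
    yield


def forward_context(overlap_allgather: bool):
    return fake_op_hook() if overlap_allgather else nullcontext()


with forward_context(overlap_allgather=True):
    pass  # run the wrapped module's forward here
```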
@ -235,7 +255,7 @@ def get_param_info(optim: Optimizer):
|
|||
return param_info
|
||||
|
||||
|
||||
def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
||||
def reinitialize_optimizer(optim: Optimizer, model: Module):
|
||||
model_params = set(model.parameters())
|
||||
new_param_groups = []
|
||||
for group in optim.param_groups:
|
||||
|
@ -257,7 +277,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
reinitialize_optimizer(optim, model)
|
||||
self.model = model
|
||||
self.stage_manager = model.stage_manager
|
||||
self.shared_params = model.shared_params
|
||||
|
@ -478,7 +498,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
||||
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
reinitialize_optimizer(optim, model)
|
||||
super().__init__(
|
||||
optim,
|
||||
precision=precision,
|
||||
|
@ -632,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
model: HybridParallelModule,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.0,
|
||||
|
@ -650,6 +671,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
overlap_allgather: bool = False,
|
||||
):
|
||||
self.model = model
|
||||
self.param_info = param_info
|
||||
|
@ -658,11 +680,12 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
self.tp_pg = tp_process_group
|
||||
self.pp_pg = pp_process_group
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optimizer, model)
|
||||
reinitialize_optimizer(optimizer, model)
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
pg_to_param_list=pg_to_param_list,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
|
@ -677,6 +700,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
cpu_offload=cpu_offload,
|
||||
dp_process_group=dp_process_group,
|
||||
forced_dtype=forced_dtype,
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
|
||||
def sync_dp_grads(self):
|
||||
|
@ -993,9 +1017,11 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
make_vocab_size_divisible_by: int = 64,
|
||||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
dist.get_world_size() % (tp_size * pp_size) == 0
|
||||
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||
|
@ -1038,17 +1064,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
self.enable_jit_fused = enable_jit_fused
|
||||
self.enable_sequence_parallelism = enable_sequence_parallelism
|
||||
if dp_outside:
|
||||
(
|
||||
self.dp_axis,
|
||||
self.pp_axis,
|
||||
self.tp_axis,
|
||||
self.sp_axis,
|
||||
) = (
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
)
|
||||
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
||||
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
|
||||
else:
|
||||
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
||||
|
@ -1148,6 +1164,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
cpu_offload=cpu_offload,
|
||||
partition_grad=(self.zero_stage == 2),
|
||||
forced_dtype=PRECISION_TORCH_TYPE[precision],
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
|
||||
self.max_norm = max_norm
|
||||
|
@ -1176,7 +1193,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return True
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
return True
|
||||
|
@ -1210,6 +1227,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
and self.enable_sequence_parallelism
|
||||
and self.sequence_parallelism_mode == "all_to_all"
|
||||
)
|
||||
# sync gradients across DP * SP ranks
|
||||
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
|
||||
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
|
||||
else:
|
||||
|
@ -1224,6 +1242,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
use_ddp=use_ddp,
|
||||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if zero_stage == 0:
|
||||
|
@ -1306,7 +1325,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
# so we disable it, performing manual reduction instead.
|
||||
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
with ctx:
|
||||
with ctx, model._wait_all_gather():
|
||||
outputs = self.schedule.forward_backward_step(
|
||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
@ -1362,15 +1381,15 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
Returns:
|
||||
Returns:`
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
|
||||
sampler = distributed_sampler_cls(
|
||||
dataset,
|
||||
num_replicas=self.pg_mesh.size(self.dp_axis),
|
||||
rank=self.pg_mesh.coordinate(self.dp_axis),
|
||||
num_replicas=self.dp_group.size(),
|
||||
rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()),
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
|
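The sampler now derives its rank from the data-parallel process group rather than from the raw mesh coordinate. `torch.distributed.get_group_rank` simply maps a global rank to its index inside a subgroup; a small sketch of how that index feeds the sampler (the group itself is assumed to have been created elsewhere):

```python
# Sketch: with 8 ranks split into two DP groups of 4, global rank 6 maps to rank 2
# inside its group, which is the index DistributedSampler needs for sharding.
import torch.distributed as dist


def sampler_rank(dp_group) -> int:
    # Mirrors the plugin's rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank())
    return dist.get_group_rank(dp_group, global_rank=dist.get_rank())
```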
@ -1402,6 +1421,24 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
def enable_lora(
|
||||
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
self,
|
||||
model: Module,
|
||||
pretrained_dir: Optional[str] = None,
|
||||
lora_config: Optional[Dict] = None,
|
||||
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
|
||||
) -> Module:
|
||||
raise NotImplementedError
|
||||
from peft import PeftModel, get_peft_model
|
||||
|
||||
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
|
||||
assert self.pp_size == 1 and self.tp_size == 1
|
||||
self.lora_enabled = True
|
||||
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
|
||||
|
||||
if bnb_quantization_config is not None:
|
||||
model = quantize_model(model, bnb_quantization_config)
|
||||
|
||||
if pretrained_dir is None:
|
||||
peft_model = get_peft_model(model, lora_config)
|
||||
else:
|
||||
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
|
||||
return peft_model
|
||||
|
|
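With `enable_lora` implemented (and `support_lora` flipped to True), the plugin can wrap a model with PEFT before boosting. A hedged usage sketch follows; the LoRA hyperparameters are illustrative, it assumes the booster forwards `enable_lora` to this plugin, and the call must happen before `booster.boost` as the assertion above requires.

```python
# Sketch: LoRA must be enabled before boosting; pp_size and tp_size must both be 1.
from peft import LoraConfig


def add_lora_then_boost(booster, model, optimizer, criterion, dataloader):
    lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")
    model = booster.enable_lora(model, lora_config=lora_config)
    return booster.boost(model, optimizer, criterion, dataloader)
```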
|
@ -2,6 +2,7 @@ import enum
|
|||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
|
@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
|||
from colossalai.interface.optimizer import DistributedOptim
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
from .torch_ddp_plugin import TorchDDPCheckpointIO
|
||||
|
@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
|
|||
|
||||
|
||||
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
def __init__(self, module: nn.Module, precision: str) -> None:
|
||||
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
|
||||
super().__init__(module)
|
||||
self.dtype = None
|
||||
if precision == "fp16":
|
||||
|
@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|||
self.convert_fn = None
|
||||
if self.dtype is not None:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
self.overlap_allgather = overlap_allgather
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
p.__init__(p, requires_grad=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
||||
with ctx:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
for p in self.module.parameters():
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
|
||||
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
|
@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_unsharded_model(model, checkpoint, strict)
|
||||
model.update_master_params()
|
||||
|
||||
|
@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|||
load_sub_module: bool = True,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
||||
model.update_master_params()
|
||||
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
return super().save_sharded_model(
|
||||
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
|
@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
|
@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
reduce_bucket_size_in_m: int = 12,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
cpu_offload: bool = False,
|
||||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
|
@ -316,6 +357,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
partition_grad=(stage == 2),
|
||||
cpu_offload=cpu_offload,
|
||||
master_weights=master_weights,
|
||||
overlap_allgather=overlap_allgather,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
self.lora_enabled = False
|
||||
|
@ -406,7 +448,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
|
||||
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
|
||||
warnings.warn(
|
||||
"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
|
||||
f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
|
||||
)
|
||||
elif (
|
||||
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
|
||||
|
@ -433,7 +475,9 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
self.add_lora_params_to_optimizer(model, optimizer)
|
||||
|
||||
if not isinstance(model, ModelWrapper):
|
||||
model = LowLevelZeroModel(model, self.precision)
|
||||
model = LowLevelZeroModel(
|
||||
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
|
||||
)
|
||||
|
||||
# TODO: Support Galore + ZeRO
|
||||
zero_stage = self.stage
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from types import MethodType
|
||||
from typing import Callable, Optional, OrderedDict, Tuple
|
||||
from typing import Callable, List, Optional, OrderedDict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
@ -11,34 +10,42 @@ from torch.nn import Module
|
|||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from colossalai.booster.plugin.hybrid_parallel_plugin import (
|
||||
PRECISION_TORCH_TYPE,
|
||||
SUPPORT_SP_MODE,
|
||||
HybridParallelAMPOptimizer,
|
||||
HybridParallelModule,
|
||||
HybridParallelNaiveOptimizer,
|
||||
HybridParallelPlugin,
|
||||
HybridParallelZeroOptimizer,
|
||||
get_param_info,
|
||||
init_pipeline_optimizer,
|
||||
reinitialize_optimizer,
|
||||
)
|
||||
from colossalai.checkpoint_io import MoECheckpointIO
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
|
||||
from colossalai.interface.optimizer import DistributedOptim
|
||||
from colossalai.nn.optimizer import cast_to_distributed
|
||||
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
|
||||
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
|
||||
from colossalai.shardformer.shard.shard_config import ShardConfig
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
dp_process_group: Optional[ProcessGroup], # the dp pg for comm
|
||||
tp_process_group: Optional[ProcessGroup], # if using tp
|
||||
pp_process_group: Optional[ProcessGroup], # if using pp
|
||||
moe_dp_group: ProcessGroup, # moe dp pg for comm
|
||||
param_info: OrderedDict,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
|
@ -51,37 +58,25 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
verbose: bool = False,
|
||||
reduce_bucket_size: int = 1024 * 1024, # communication
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
overlap_communication: bool = False,
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
pp_process_group: Optional[ProcessGroup] = None,
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
|
||||
overlap_allgather: bool = False,
|
||||
):
|
||||
self.param_info = param_info
|
||||
self.stage_manager = model.stage_manager
|
||||
self.shared_params = model.shared_params
|
||||
self.dp_pg = dp_process_group
|
||||
self.tp_pg = tp_process_group
|
||||
self.pp_pg = pp_process_group
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optimizer, model)
|
||||
|
||||
pg_param_list = {
|
||||
dp_process_group: [],
|
||||
moe_extra_dp_process_group: [],
|
||||
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
|
||||
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
|
||||
}
|
||||
for param in model.parameters():
|
||||
if is_moe_tensor(param):
|
||||
pg_param_list[moe_extra_dp_process_group].append(param)
|
||||
else:
|
||||
pg_param_list[dp_process_group].append(param)
|
||||
|
||||
if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
|
||||
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
pg_to_param_list=pg_param_list,
|
||||
use_pipeline=use_pipeline,
|
||||
param_info=param_info,
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
|
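The rewritten constructor builds `pg_param_list` by splitting the model's parameters into MoE and non-MoE buckets and fails fast if either bucket is empty. A minimal sketch of that split; the predicate argument is a stand-in for `colossalai.tensor.moe_tensor.api.is_moe_tensor`.

```python
# Sketch: partition parameters into per-process-group buckets for the ZeRO optimizer.
def split_params_by_group(model, dp_group, moe_dp_group, is_moe_param):
    pg_param_list = {
        dp_group: [p for p in model.parameters() if not is_moe_param(p)],
        moe_dp_group: [p for p in model.parameters() if is_moe_param(p)],
    }
    if len(pg_param_list[dp_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
        raise ValueError("No parameters found in dp_process_group or moe_dp_group")
    return pg_param_list
```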
@ -96,30 +91,37 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
overlap_communication=overlap_communication,
|
||||
partition_grad=partition_grad,
|
||||
cpu_offload=cpu_offload,
|
||||
tp_process_group=tp_process_group,
|
||||
pp_process_group=pp_process_group,
|
||||
forced_dtype=forced_dtype,
|
||||
pg_to_param_list=pg_param_list,
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
|
||||
|
||||
class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
"""
|
||||
Plugin for Moe Hybrid Parallel Training.
|
||||
Plugin for MoE Hybrid Parallel Training, which is similar to HybridParallelPlugin.
|
||||
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
|
||||
The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
|
||||
|
||||
Example:
|
||||
>>> from colossalai.booster import Booster
|
||||
>>> from colossalai.booster.plugin import HybridParallelPlugin
|
||||
```python
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import MoeHybridParallelPlugin
|
||||
|
||||
>>> model, train_dataset, optimizer, criterion = ...
|
||||
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
|
||||
model, train_dataset, optimizer, criterion = ...
|
||||
plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2, ep_size=2)
|
||||
|
||||
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
|
||||
>>> booster = Booster(plugin=plugin)
|
||||
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
|
||||
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
|
||||
```
|
||||
|
||||
Args:
|
||||
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
|
||||
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
|
||||
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
|
||||
ep_size (int): The size of expert parallelism
|
||||
sp_size (int): The size of sequence parallelism.
|
||||
precision (str, optional): Specifies the precision of parameters during training.
|
||||
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
|
||||
Defaults to 'fp16'.
|
||||
|
@ -132,7 +134,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
|
||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
|
||||
sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
|
||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
|
||||
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
|
||||
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
|
||||
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
|
||||
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
|
||||
|
@ -155,15 +159,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
|
||||
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
|
||||
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
|
||||
use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
|
||||
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
|
||||
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
|
||||
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
|
||||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
|
||||
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
ep_size: int,
|
||||
tp_size: int = 1,
|
||||
sp_size: int = 1,
|
||||
sp_size: int = None,
|
||||
precision: str = "fp16",
|
||||
zero_stage: int = 0,
|
||||
enable_all_optimization: bool = False,
|
||||
|
@ -171,7 +181,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
enable_flash_attention: bool = False,
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
sequence_parallelism_mode: str = None,
|
||||
enable_sequence_overlap: bool = False,
|
||||
parallel_output: bool = True,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
initial_scale: float = 2**16,
|
||||
|
@ -191,27 +203,61 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
zero_bucket_size_in_m: int = 12,
|
||||
cpu_offload: bool = False,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
use_ep_inside: bool = True,
|
||||
overlap_communication: bool = False,
|
||||
custom_policy: Policy = None,
|
||||
checkpoint_io: Optional[MoECheckpointIO] = None,
|
||||
pp_style: str = "1f1b",
|
||||
num_model_chunks: int = 1,
|
||||
num_layers_per_stage: Optional[List[int]] = None,
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
moe_dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
) -> None:
|
||||
world_size = dist.get_world_size()
|
||||
assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
|
||||
assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet"
|
||||
if overlap_communication or zero_stage == 2:
|
||||
overlap_communication = False
|
||||
zero_stage = 1
|
||||
warnings.warn(
|
||||
f"overlap_communication and zero_stage are set to False and 1 because "
|
||||
f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
|
||||
)
|
||||
|
||||
assert (
|
||||
world_size % (tp_size * pp_size) == 0
|
||||
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||
assert (
|
||||
world_size % (tp_size * pp_size * ep_size) == 0
|
||||
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
|
||||
dist.get_world_size() % (tp_size * pp_size) == 0
|
||||
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||
if enable_sequence_parallelism:
|
||||
self.sequence_parallelism_mode = (
|
||||
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
|
||||
)
|
||||
assert (
|
||||
self.sequence_parallelism_mode in SUPPORT_SP_MODE
|
||||
), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
|
||||
if self.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
assert (
|
||||
tp_size > 1
|
||||
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
|
||||
if sp_size != 1:
|
||||
warnings.warn(
|
||||
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
|
||||
)
|
||||
self.sp_size = 1
|
||||
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
|
||||
elif self.sequence_parallelism_mode in ["all_to_all"]:
|
||||
self.sp_size = 1 if sp_size is None else sp_size
|
||||
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
|
||||
else:
|
||||
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
|
||||
assert (
|
||||
sp_size == 1 or sp_size is None
|
||||
), f"You should not set sp_size when sequence parallelism is not enabled."
|
||||
self.sp_size = 1
|
||||
|
||||
self.dp_size = world_size // (tp_size * pp_size)
|
||||
assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}"
|
||||
self.moe_dp_size = self.dp_size // ep_size
|
||||
self.ep_size = ep_size
|
||||
self.tp_size = tp_size
|
||||
self.pp_size = pp_size
|
||||
self.ep_size = ep_size
|
||||
self.sp_size = sp_size
|
||||
self.precision = precision
|
||||
self.zero_stage = zero_stage
|
||||
self.cpu_offload = cpu_offload
|
||||
|
@ -220,61 +266,69 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
self.enable_flash_attention = enable_flash_attention
|
||||
self.enable_jit_fused = enable_jit_fused
|
||||
self.enable_sequence_parallelism = enable_sequence_parallelism
|
||||
self.checkpoint_io = checkpoint_io
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
# NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
|
||||
# See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
|
||||
# we change pg mesh to (pp, dp, tp) for better moe performance
|
||||
assert (
|
||||
self.ep_size <= self.dp_size
|
||||
), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
|
||||
|
||||
self.moe_dp_size = self.dp_size // self.ep_size
|
||||
self.use_ep_inside = use_ep_inside
|
||||
if self.use_ep_inside:
|
||||
logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
|
||||
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
|
||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
|
||||
if moe_dp_outside:
|
||||
self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
|
||||
self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size)
|
||||
else:
|
||||
logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
|
||||
warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
|
||||
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
|
||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
|
||||
self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
|
||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
|
||||
|
||||
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
|
||||
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
|
||||
logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
|
||||
logger.info(
|
||||
f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
|
||||
)
|
||||
|
||||
self.tp_group = self.pg_mesh.get_group_along_axis(
|
||||
self.tp_axis
|
||||
) # TODO: support custom tp size for mixtral lm head
|
||||
self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
|
||||
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
|
||||
# TODO: Currently moe only support partially sequence parallel
|
||||
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
||||
|
||||
self.custom_policy = custom_policy
|
||||
self.stage_manager = None
|
||||
self.schedule = None
|
||||
|
||||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
|
||||
assert (
|
||||
self.zero_stage <= 1
|
||||
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
self.stage_manager = PipelineStageManager(
|
||||
self.pg_mesh,
|
||||
pipeline_axis=self.pp_axis,
|
||||
enable_interleave=pp_style == "interleaved",
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_layers_per_stage=num_layers_per_stage,
|
||||
)
|
||||
|
||||
if pp_style == "interleaved":
|
||||
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
||||
self.schedule = InterleavedSchedule(
|
||||
stage_manager=self.stage_manager,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
overlap_p2p=overlap_p2p,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
stage_manager=self.stage_manager,
|
||||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
||||
self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis])
|
||||
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
|
||||
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis)
|
||||
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
|
||||
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
||||
else:
|
||||
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
|
||||
|
||||
self.shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=self.tp_group,
|
||||
sequence_parallel_process_group=self.sp_group,
|
||||
ep_group=self.ep_group,
|
||||
moe_dp_group=self.moe_dp_group,
|
||||
pipeline_stage_manager=self.stage_manager,
|
||||
enable_tensor_parallelism=self.tp_size > 1,
|
||||
enable_all_optimization=self.enable_all_optimization,
|
||||
|
@ -282,8 +336,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
enable_flash_attention=self.enable_flash_attention,
|
||||
enable_jit_fused=self.enable_jit_fused,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
sequence_parallelism_mode=sequence_parallelism_mode,
|
||||
enable_sequence_overlap=enable_sequence_overlap,
|
||||
ep_group=self.ep_group,
|
||||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
initial_scale=initial_scale,
|
||||
|
@ -310,77 +367,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
overlap_communication=overlap_communication,
|
||||
cpu_offload=cpu_offload,
|
||||
partition_grad=(self.zero_stage == 2),
|
||||
forced_dtype=PRECISION_TORCH_TYPE[precision],
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
|
||||
self.max_norm = max_norm
|
||||
|
||||
def prepare_dataloader(
|
||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
|
||||
|
||||
|
||||
Args:
|
||||
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
sampler = DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=self.dp_size,
|
||||
rank=dist.get_rank(self.global_dp_group),
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
|
||||
def get_checkpoint_io(self) -> MoECheckpointIO:
|
||||
if self.checkpoint_io is None:
|
||||
self.checkpoint_io = MoECheckpointIO(
|
||||
self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
|
||||
)
|
||||
else:
|
||||
self.checkpoint_io = self.checkpoint_io(
|
||||
self.global_dp_group,
|
||||
self.pp_group,
|
||||
self.tp_group,
|
||||
ep_group=self.ep_group,
|
||||
moe_dp_group=self.moe_dp_group,
|
||||
zero_stage=self.zero_stage,
|
||||
)
|
||||
if hasattr(self.checkpoint_io, "moe_info"):
|
||||
self.checkpoint_io.moe_info = self.moe_info
|
||||
return self.checkpoint_io
|
||||
return MoECheckpointIO(
|
||||
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
|
||||
)
|
||||
|
||||
def configure(
|
||||
self,
|
||||
|
@ -391,13 +387,40 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
param_info = get_param_info(optimizer)
|
||||
|
||||
# TODO: Support Galore + ZeRO
|
||||
# Replace with distributed implementation if exists
|
||||
optimizer = cast_to_distributed(optimizer)
|
||||
|
||||
if not isinstance(model, ModelWrapper):
|
||||
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
||||
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
|
||||
self.dp_size == 1
|
||||
and self.pp_size == 1
|
||||
and self.enable_sequence_parallelism
|
||||
and self.sequence_parallelism_mode == "all_to_all"
|
||||
)
|
||||
if use_ddp:
|
||||
warnings.warn(
|
||||
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
|
||||
)
|
||||
self.ddp_config["find_unused_parameters"] = True
|
||||
|
||||
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
|
||||
raise ValueError(
|
||||
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
|
||||
)
|
||||
|
||||
# sync gradients across DP * SP ranks
|
||||
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
|
||||
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
|
||||
else:
|
||||
dp_group = self.dp_group
|
||||
|
||||
model = HybridParallelModule(
|
||||
module=model,
|
||||
precision=self.precision,
|
||||
shard_config=self.shard_config,
|
||||
dp_group=self.global_dp_group,
|
||||
dp_group=dp_group,
|
||||
tp_group=self.tp_group,
|
||||
sp_group=self.sp_group,
|
||||
use_ddp=use_ddp,
|
||||
|
@ -405,7 +428,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
custom_policy=self.custom_policy,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.ep_size > 1:
|
||||
# if ep is enabled, the num of (moe) parameters changed since they are sharded among ep groups
|
||||
# but the optimizer is not aware of ep, so we need to update the optimizer
|
||||
reinitialize_optimizer(optimizer, model)
|
||||
|
||||
if self.zero_stage == 0:
|
||||
is_zero = False
|
||||
if self.precision in ["fp16", "bf16"]:
|
||||
optimizer = HybridParallelAMPOptimizer(
|
||||
optimizer,
|
||||
|
@ -418,20 +447,30 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
)
|
||||
else:
|
||||
optimizer = HybridParallelNaiveOptimizer(
|
||||
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
max_norm=self.max_norm,
|
||||
pp_process_group=self.pp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
)
|
||||
else:
|
||||
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
|
||||
if self.dp_size <= 1:
|
||||
warnings.warn(
|
||||
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
|
||||
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
|
||||
)
|
||||
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
|
||||
optimizer = MoeHybridParallelZeroOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
dp_process_group=self.global_dp_group,
|
||||
dp_process_group=dp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
pp_process_group=self.pp_group,
|
||||
moe_extra_dp_process_group=self.moe_dp_group,
|
||||
moe_dp_group=self.moe_dp_group,
|
||||
verbose=True,
|
||||
clip_grad_norm=self.max_norm,
|
||||
**self.zero_config,
|
||||
|
@ -440,4 +479,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
# inject update_master_params
|
||||
model.update_master_params = MethodType(optimizer.update_master_params, model)
|
||||
|
||||
# Setup optimizers that require global states
|
||||
optim = optimizer.optim
|
||||
if isinstance(optim, DistributedOptim):
|
||||
shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
|
||||
padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
|
||||
optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
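For context, a hedged sketch of how `configure` is normally reached through the Booster API; the plugin arguments and the model/dataloader are illustrative assumptions rather than the only valid configuration, and a launched distributed environment is assumed.

```python
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

def boost_moe(model: torch.nn.Module, train_dataloader):
    # illustrative settings: expert parallel size 2 with ZeRO-1 and bf16
    plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2, zero_stage=1, precision="bf16")
    booster = Booster(plugin=plugin)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # criterion maps (outputs, inputs) to a scalar loss, as the pipeline schedule expects
    criterion = lambda outputs, inputs: outputs.loss
    # returns the wrapped (model, optimizer, criterion, dataloader, lr_scheduler) tuple built above
    return booster.boost(model, optimizer, criterion=criterion, dataloader=train_dataloader)
```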
|
||||
@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
model = model.unwrap()
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
|
@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
This argument should be manually set to False since params on same device might be stored in different files.
|
||||
"""
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
model_before_wrapping = model # backup for model before wrapping
|
||||
model = model.unwrap()
|
||||
|
||||
|
@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
model = model.unwrap()
|
||||
|
||||
if self.dp_rank != 0:
|
||||
|
@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
strict = False
|
||||
model_before_wrapping = model
|
||||
model = model.unwrap()
|
||||
|
@ -943,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
|
||||
@ -151,13 +151,10 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
|||
|
||||
# ep_rank 0 saves all the parameters and buffers.
|
||||
# other ep_ranks save only experts
|
||||
ep_param_pattern = "experts." if self.ep_rank != 0 else None
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = MoECheckpointIO._model_sharder(
|
||||
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
|
||||
)
|
||||
state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
|
|
@ -44,7 +44,7 @@ class DistCoordinator(metaclass=SingletonMeta):
|
|||
self._rank = dist.get_rank()
|
||||
self._world_size = dist.get_world_size()
|
||||
# this is often passed by launchers such as torchrun
|
||||
self._local_rank = os.environ.get("LOCAL_RANK", -1)
|
||||
self._local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import GroupMember
|
||||
|
||||
|
||||
def prod(nums: List[int]) -> int:
|
||||
|
@ -47,7 +48,7 @@ class ProcessGroupMesh:
|
|||
self._shape = size
|
||||
self._rank = dist.get_rank()
|
||||
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
|
||||
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
|
||||
self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember.NON_GROUP_MEMBER]] = {}
|
||||
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
|
||||
|
||||
def destroy_mesh_process_groups(self):
|
||||
|
@ -136,7 +137,7 @@ class ProcessGroupMesh:
|
|||
assert mode in ["raise", "wrap", "clip"]
|
||||
return int(np.ravel_multi_index(coord, shape, mode))
|
||||
|
||||
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
|
||||
def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
|
||||
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
|
||||
|
||||
Args:
|
||||
|
@ -147,10 +148,11 @@ class ProcessGroupMesh:
|
|||
ProcessGroup: The process group with the given ranks.
|
||||
"""
|
||||
ranks_in_group = sorted(ranks_in_group)
|
||||
if tuple(ranks_in_group) not in self._group_to_ranks:
|
||||
if tuple(ranks_in_group) not in self._ranks_to_group:
|
||||
group = dist.new_group(ranks_in_group, backend=backend)
|
||||
self._ranks_to_group[tuple(ranks_in_group)] = group
|
||||
self._group_to_ranks[group] = tuple(ranks_in_group)
|
||||
if group is not GroupMember.NON_GROUP_MEMBER:
|
||||
self._group_to_ranks[group] = tuple(ranks_in_group)
|
||||
return self._ranks_to_group[tuple(ranks_in_group)]
|
||||
|
||||
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
|
||||
|
@ -238,7 +240,7 @@ class ProcessGroupMesh:
|
|||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
group = self.get_group(ranks_in_group, backend=backend)
|
||||
group = self._get_group(ranks_in_group, backend=backend)
|
||||
if self._rank in ranks_in_group:
|
||||
target_group = group
|
||||
return target_group
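A hedged sketch of how these mesh helpers are used (the backend and process count are assumptions for illustration): ranks that are not members of a newly created group receive `GroupMember.NON_GROUP_MEMBER`, which is why the change above only records the group-to-ranks mapping for real group handles.

```python
# run with: torchrun --nproc_per_node 4 this_script.py
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh

dist.init_process_group(backend="gloo")   # gloo chosen only to keep the sketch CPU-friendly
mesh = ProcessGroupMesh(2, 2)             # a 2 x 2 mesh, world size must be 4
group = mesh.get_group_along_axis(1)      # sub-group of ranks sharing the same axis-0 coordinate
print(dist.get_rank(), mesh.get_ranks_in_group(group))
```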
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
|
||||
## 📌 Introduction
|
||||
ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
|
||||
ColossalAI-Inference is a module which accelerates the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>
|
||||
|
@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below.
|
|||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
# Distrifusion
|
||||
@InProceedings{Li_2024_CVPR,
|
||||
author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
|
||||
title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
month={June},
|
||||
year={2024},
|
||||
pages={7183-7193}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
|
|||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.generation import GenerationConfig
|
||||
|
@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
|
|||
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
|
||||
start_token_size(int): The size of the start tokens, when using StreamingLLM.
|
||||
generated_token_size(int): The size of the generated tokens, when using StreamingLLM.
|
||||
patched_parallelism_size(int): The degree of patched parallelism, when using Distrifusion.
|
||||
"""
|
||||
|
||||
# NOTE: arrange configs according to their importance and frequency of usage
|
||||
|
@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
|
|||
start_token_size: int = 4
|
||||
generated_token_size: int = 512
|
||||
|
||||
# Acceleration for Diffusion Model(PipeFusion or Distrifusion)
|
||||
patched_parallelism_size: int = 1 # for distrifusion
|
||||
# pipeFusion_m_size: int = 1 # for pipefusion
|
||||
# pipeFusion_n_size: int = 1 # for pipefusion
|
||||
|
||||
def __post_init__(self):
|
||||
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
|
||||
self._verify_config()
|
||||
|
@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
|
|||
# Thereafter, we swap out tokens in units of blocks, always swapping out the second block when the generated tokens exceed the limit.
|
||||
self.start_token_size = self.block_size
|
||||
|
||||
# check Distrifusion
|
||||
# TODO(@lry89757) need more detailed check
|
||||
if self.patched_parallelism_size > 1:
|
||||
# self.use_patched_parallelism = True
|
||||
self.tp_size = (
|
||||
self.patched_parallelism_size
|
||||
) # not real tensor parallelism; some checks require tp_size, so we have to set it to patched_parallelism_size
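A hedged illustration of the check above (the field values are arbitrary): enabling Distrifusion forces `tp_size` to track `patched_parallelism_size`.

```python
from colossalai.inference.config import InferenceConfig

# split the image across 2 devices via patched parallelism (Distrifusion)
config = InferenceConfig(dtype="fp16", patched_parallelism_size=2)
assert config.tp_size == config.patched_parallelism_size  # enforced by __post_init__ above
```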
|
||||
|
||||
# check prompt template
|
||||
if self.prompt_template is None:
|
||||
return
|
||||
|
@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
|
|||
use_cuda_kernel=self.use_cuda_kernel,
|
||||
use_spec_dec=self.use_spec_dec,
|
||||
use_flash_attn=use_flash_attn,
|
||||
patched_parallelism_size=self.patched_parallelism_size,
|
||||
)
|
||||
return model_inference_config
|
||||
|
||||
|
@ -396,3 +411,50 @@ class ModelShardInferenceConfig:
|
|||
use_cuda_kernel: bool = False
|
||||
use_spec_dec: bool = False
|
||||
use_flash_attn: bool = False
|
||||
patched_parallelism_size: int = 1 # for diffusion model, Distrifusion Technique
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionGenerationConfig:
|
||||
"""
|
||||
Parameters for the diffusion model forward pass
|
||||
"""
|
||||
|
||||
prompt_2: Optional[Union[str, List[str]]] = None
|
||||
prompt_3: Optional[Union[str, List[str]]] = None
|
||||
height: Optional[int] = None
|
||||
width: Optional[int] = None
|
||||
num_inference_steps: int = None
|
||||
timesteps: List[int] = None
|
||||
guidance_scale: float = None
|
||||
negative_prompt: Optional[Union[str, List[str]]] = (
|
||||
None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
|
||||
)
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None
|
||||
num_images_per_prompt: Optional[int] = None
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
|
||||
latents: Optional[torch.FloatTensor] = None
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
output_type: Optional[str] = None # "pil"
|
||||
return_dict: bool = None
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None
|
||||
clip_skip: Optional[int] = None
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
|
||||
callback_on_step_end_tensor_inputs: List[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# NOTE(@lry89757) Only return fields that are not the default value None
|
||||
result = {}
|
||||
for field in fields(self):
|
||||
value = getattr(self, field.name)
|
||||
if value is not None:
|
||||
result[field.name] = value
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
|
||||
return cls(**kwargs)
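A brief usage sketch (assuming the class ships in `colossalai.inference.config`, as this diff suggests): only explicitly set fields survive `to_dict()`, so the underlying diffusers pipeline keeps its own defaults for everything left as `None`.

```python
from colossalai.inference.config import DiffusionGenerationConfig

cfg = DiffusionGenerationConfig.from_kwargs(height=1024, width=1024, num_inference_steps=20)
assert cfg.to_dict() == {"height": 1024, "width": 1024, "num_inference_steps": 20}
```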
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
|
||||
"""
|
||||
Init Model for Engine
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
|
||||
"""
|
||||
Generate output for incoming requests
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_request(self, prompts, request_ids=None, **kwargs):
|
||||
"""
|
||||
Add new request to Engine
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def step(self):
|
||||
"""
|
||||
Perform one new step forward
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _verify_args(self):
|
||||
"""
|
||||
Verify the parameters and members of class
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self):
|
||||
"""
|
||||
Use cuda graph to capture model
|
||||
"""
|
||||
raise NotImplementedError("This method should be implemented by subclasses")
|
||||
|
||||
def _shardformer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
**kwargs,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Initialize ShardConfig and replace the model with shardformer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be sharded.
|
||||
model_policy (Policy): The policy to shardformer model which is determined by the model type.
|
||||
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
|
||||
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model optimized by Shardformer.
|
||||
"""
|
||||
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model
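To make the contract concrete, a toy subclass of the `BaseEngine` interface above; it is purely illustrative (no sharding, no real model) and only shows which methods a concrete engine must provide.

```python
class EchoEngine(BaseEngine):
    """Illustrative engine that simply echoes prompts back, one per step."""

    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        self.inference_config = inference_config
        self.queue = []

    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        pass  # nothing to shard or load in this toy example

    def add_request(self, prompts, request_ids=None, **kwargs):
        self.queue.extend(prompts if isinstance(prompts, list) else [prompts])

    def step(self):
        return [self.queue.pop(0)] if self.queue else []

    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        if prompts is not None:
            self.add_request(prompts)
        outputs = []
        while self.queue:
            outputs += self.step()
        return outputs

    def _verify_args(self):
        pass
```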
|
|
@ -0,0 +1,200 @@
|
|||
from itertools import count
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from torch import distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
|
||||
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.struct import DiffusionSequence
|
||||
from colossalai.inference.utils import get_model_size, get_model_type
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .base_engine import BaseEngine
|
||||
from .request_handler import NaiveRequestHandler
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
|
||||
class DiffusionEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: DiffusionPipeline | str,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Policy | type[Policy] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
self.high_precision = inference_config.high_precision
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
|
||||
self.model_type = get_model_type(model_or_path=model_or_path)
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.request_handler = NaiveRequestHandler()
|
||||
|
||||
self.counter = count()
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"
|
||||
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[str, nn.Module, DiffusionPipeline],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard the model and/or load weights
|
||||
|
||||
Args:
|
||||
model_or_path (Union[str, nn.Module, DiffusionPipeline]): path to the checkpoint or a model in transformers/diffusers format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for modeling initialization when inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
if isinstance(model_or_path, str):
|
||||
model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
|
||||
policy_map_key = model.__class__.__name__
|
||||
model = DiffusionPipe(model)
|
||||
elif isinstance(model_or_path, DiffusionPipeline):
|
||||
policy_map_key = model_or_path.__class__.__name__
|
||||
model = DiffusionPipe(model_or_path)
|
||||
else:
|
||||
self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
model_policy = model_policy_map.get(policy_map_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = model.to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
generation_config: DiffusionGenerationConfig = None,
|
||||
**kwargs,
|
||||
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
|
||||
""" """
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
**gen_config_dict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output_reqs_list = []
|
||||
|
||||
# intuition: If the user provides a generation config, we should replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_reqs_list += self.step()
|
||||
|
||||
return output_reqs_list
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
prompts: Union[List[str], str],
|
||||
request_ids: Union[List[int], int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
|
||||
prompts_num = len(prompts)
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
|
||||
seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)
|
||||
|
||||
self.request_handler.add_sequence(seq)
|
||||
|
||||
def step(self) -> List[PIL.Image.Image]:
|
||||
"""
|
||||
In each step, do the following:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run the forward pass to get List[PIL.Image.Image].
Returns:
List[PIL.Image.Image]: Images generated by one step.
|
||||
"""
|
||||
|
||||
input = self.request_handler.schedule()
|
||||
ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
|
||||
return ret
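A hedged end-to-end sketch of the new engine (the model id, the output handling, and the assumption of an already-launched ColossalAI/distributed context are illustrative, not guaranteed by this diff):

```python
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.diffusion_engine import DiffusionEngine  # assumed module path

inference_config = InferenceConfig(dtype="fp16")
engine = DiffusionEngine(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # any DiffusionPipeline id or local path
    inference_config=inference_config,
)
images = engine.generate(prompts="an astronaut riding a horse on the moon")
images[0].save("output.png")
```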
|
|
@ -1,57 +1,24 @@
|
|||
import time
|
||||
from itertools import count
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import PIL.Image
|
||||
import torch.nn as nn
|
||||
from torch import distributed as dist
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
from colossalai.inference.spec import Drafter, GlideInput
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.inference.utils import get_model_size, has_index_file
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.utils import ModelType, get_model_type
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .request_handler import RequestHandler
|
||||
|
||||
__all__ = ["InferenceEngine"]
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = {
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"BaichuanForCausalLM": AutoModelForCausalLM,
|
||||
}
|
||||
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""
|
||||
InferenceEngine, which manages the inference process.
|
||||
|
||||
Args:
|
||||
model_or_path (nn.Module or str): Path or nn.Module of this model.
|
||||
model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
|
||||
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
|
||||
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
|
||||
verbose (bool): Determine whether or not to log the generation process.
|
||||
|
@ -60,567 +27,68 @@ class InferenceEngine:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
inference_config: InferenceConfig,
|
||||
model_or_path: Union[nn.Module, str, DiffusionPipeline],
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
self.high_precision = inference_config.high_precision
|
||||
self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__
|
||||
self.model_type = get_model_type(model_or_path=model_or_path)
|
||||
self.engine = None
|
||||
if self.model_type == ModelType.LLM:
|
||||
from .llm_engine import LLMEngine
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
self.engine = LLMEngine(
|
||||
model_or_path=model_or_path,
|
||||
tokenizer=tokenizer,
|
||||
inference_config=inference_config,
|
||||
verbose=verbose,
|
||||
model_policy=model_policy,
|
||||
)
|
||||
elif self.model_type == ModelType.DIFFUSION_MODEL:
|
||||
from .diffusion_engine import DiffusionEngine
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.request_handler = RequestHandler(self.inference_config, self.model_config)
|
||||
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
|
||||
# DISCUSS maybe move this into batch info?
|
||||
|
||||
self.counter = count()
|
||||
|
||||
self.use_cuda_graph = self.inference_config.use_cuda_graph
|
||||
if self.use_cuda_graph:
|
||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||
self.graph_memory_pool = None # Set during graph capture.
|
||||
if verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture on")
|
||||
|
||||
self.capture_model(self.k_cache, self.v_cache)
|
||||
|
||||
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||
self.use_spec_dec = self.inference_config.use_spec_dec
|
||||
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.engine = DiffusionEngine(
|
||||
model_or_path=model_or_path,
|
||||
inference_config=inference_config,
|
||||
verbose=verbose,
|
||||
model_policy=model_policy,
|
||||
)
|
||||
elif self.model_type == ModelType.UNKNOWN:
|
||||
self.logger.error(f"Model Type either Difffusion or LLM!")
|
||||
|
||||
self._initialized = True
|
||||
self._verify_args()
|
||||
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard model or/and Load weight
|
||||
|
||||
Args:
|
||||
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for modeling initialization when inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
if arch in _supported_models.keys():
|
||||
if arch is "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _supported_models[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
)
|
||||
else:
|
||||
model = model_or_path
|
||||
|
||||
self.model_config = model.config
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
model = model.to(self.dtype).eval()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
|
||||
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
|
||||
model_policy = model_policy_map.get(model_policy_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = ModelWrapper(model).to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
|
||||
assert self.use_cuda_graph, "please turn on the cuda graph"
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture begin")
|
||||
|
||||
t_capture_begin = time.perf_counter()
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
|
||||
|
||||
# Prepare dummy inputs. These will be reused for all batch sizes.
|
||||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
||||
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
|
||||
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
|
||||
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
|
||||
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
|
||||
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
|
||||
self.graph_block_tables[0, :] = np.arange(
|
||||
0, max_num_blocks
|
||||
) # NOTE this is a hack to ensure the cuda graph can capture the fixed cuda kernel grid in flash decoding, by making the first seqlen the max_seq_len
|
||||
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
||||
output_tensor = torch.zeros(
|
||||
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
|
||||
)
|
||||
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
|
||||
|
||||
max_num_seqs = self.inference_config.max_batch_size
|
||||
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
|
||||
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
|
||||
# NOTE this is a hack to ensure the cuda graph can capture the fixed cuda kernel grid in flash decoding, by making the first seqlen the max_seq_len
|
||||
sequence_lengths[0] = torch.tensor(
|
||||
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
|
||||
).cuda()
|
||||
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
if self.verbose:
|
||||
self.logger.info(f"batch size {batch_size} graph capturing")
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=block_tables[:batch_size],
|
||||
sequence_lengths=sequence_lengths[:batch_size],
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
batch_size=batch_size,
|
||||
is_prompts=False,
|
||||
use_cuda_graph=True,
|
||||
high_precision=False,
|
||||
kv_seq_len=sequence_lengths[:batch_size].max().item(),
|
||||
head_dim=head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens_ids[:batch_size],
|
||||
output_tensor[:batch_size],
|
||||
input_meta_data,
|
||||
k_caches=k_cache,
|
||||
v_caches=v_cache,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
|
||||
t_capture_end = time.perf_counter()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
"""Verify the input args"""
|
||||
if not isinstance(self.inference_config, InferenceConfig):
|
||||
raise TypeError("Invalid type of inference config provided.")
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||
raise TypeError(
|
||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||
)
|
||||
if isinstance(self.model, ModelWrapper):
|
||||
model = self.model.module
|
||||
assert (
|
||||
model.__class__.__name__ in _supported_models.keys()
|
||||
), f"Model {self.model.__class__.__name__} is not supported."
|
||||
|
||||
def _shardformer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Initialize ShardConfig and replace the model with shardformer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Path or nn.Module of this model.
|
||||
model_policy (Policy): The policy to shardformer model which is determined by the model type.
|
||||
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
|
||||
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model optimized by Shardformer.
|
||||
"""
|
||||
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model
|
||||
|
||||
def enable_spec_dec(
|
||||
self,
|
||||
drafter_model: nn.Module = None,
|
||||
n_spec_tokens: int = None,
|
||||
use_glide_drafter: bool = False,
|
||||
) -> None:
|
||||
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||
|
||||
Args:
|
||||
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||
If provided, the previous drafter and drafter model, if exist, will be overwritten.
|
||||
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||
If True, the drafter model will be replaced by a glide model.
|
||||
|
||||
```python
|
||||
...
|
||||
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
engine.generate(...) # Speculative Decoding
|
||||
|
||||
engine.disable_spec_dec()
|
||||
engine.generate(...) # Normal generation
|
||||
|
||||
engine.enable_spec_dec()
|
||||
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
if n_spec_tokens is not None:
|
||||
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||
self.n_spec_tokens = n_spec_tokens
|
||||
if drafter_model is not None:
|
||||
assert isinstance(drafter_model, nn.Module)
|
||||
# overwrite the drafter, if exists
|
||||
self.clear_spec_dec()
|
||||
self.drafter_model = drafter_model
|
||||
self.drafter = Drafter(
|
||||
self.drafter_model,
|
||||
self.tokenizer,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# check if the provided drafter model is compatible with GLIDE structure
|
||||
# when `use_glide_drafter` is set to True
|
||||
if (
|
||||
use_glide_drafter
|
||||
and hasattr(drafter_model, "model")
|
||||
and hasattr(drafter_model.model, "layers")
|
||||
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||
):
|
||||
self.use_glide = use_glide_drafter
|
||||
elif use_glide_drafter:
|
||||
self.logger.warning(
|
||||
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||
)
|
||||
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||
# using speculative decoding for subsequent generations
|
||||
self.use_spec_dec = True
|
||||
|
||||
def disable_spec_dec(self) -> None:
|
||||
"""Disable using speculative decoding for subsequent generations."""
|
||||
self.request_handler.unset_spec_dec_mode()
|
||||
# set back to the maximum number of tokens to speculate
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def clear_spec_dec(self) -> None:
|
||||
"""Clear relatable structures of speculative decoding, if exist."""
|
||||
if self.use_spec_dec:
|
||||
self.disable_spec_dec()
|
||||
if self.drafter_model or self.drafter:
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
torch.cuda.empty_cache()
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def steps_spec_dec(self) -> List[Sequence]:
|
||||
"""
|
||||
Run Speculative Decoding steps. This retrieves a single batch and launches inference
with many rounds of speculation by a drafter model, each verified by the main model.
|
||||
|
||||
Returns:
|
||||
List[Sequence]: finished sequences generated by one step.
|
||||
"""
|
||||
batch = self.request_handler.schedule() # prefill batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
|
||||
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
# append new inputs to the batch, temporarily
|
||||
batch.append_batch_tokens(next_tokens)
|
||||
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(1)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
while True:
|
||||
# HACK Retrieve the running batch
|
||||
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||
batch = self.request_handler.running_bb # running batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
# 3. Decoding - Drafter model speculates `n` tokens
|
||||
glide_input = None
|
||||
if self.use_glide:
|
||||
glide_input = GlideInput(
|
||||
batch.get_block_table_tensor(),
|
||||
self.k_cache[-1], # use kv caches of the last layer
|
||||
self.v_cache[-1],
|
||||
batch.get_sequence_lengths(),
|
||||
n_spec_tokens=self.n_spec_tokens,
|
||||
)
|
||||
|
||||
drafter_out = self.drafter.speculate(
|
||||
input_token_ids,
|
||||
self.n_spec_tokens,
|
||||
drafter_past_key_values,
|
||||
glide_input=glide_input,
|
||||
)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
drafter_spec_length = drafter_out.speculated_length
|
||||
|
||||
for next_token_id_spec in next_token_ids_spec:
|
||||
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||
cur_length = batch.seq_lengths[0].item()
|
||||
if already_allocated_kv_len < cur_length:
|
||||
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||
already_allocated_kv_len = cur_length
|
||||
|
||||
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
|
||||
# 5. Compare and process the results
|
||||
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||
|
||||
# revoke appended tokens for each Sequence in the current batch
|
||||
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||
|
||||
# append the last correct token generated by the main model
|
||||
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||
|
||||
# trim past key values of the drafter model
|
||||
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||
)
|
||||
|
||||
# prepare inputs for the next round of speculation
|
||||
n = 1 if n_matches < drafter_spec_length else 2
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(n)
|
||||
|
||||
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||
finished_sequences = self.request_handler.update()
|
||||
if len(finished_sequences) > 0:
|
||||
break
|
||||
|
||||
# Reset back the number of speculated tokens of the batch,
|
||||
# this is used to handle the last round of speculation, in which case the number of speculated tokens
|
||||
# by the drafter is less than the number of speculated tokens set to the engine.
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||
|
||||
return finished_sequences
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
assert self._initialized, "Engine must be initialized"
|
||||
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Executing the inference step.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
|
||||
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
|
||||
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
prompts_token_ids=prompts_token_ids,
|
||||
**gen_config_dict,
|
||||
)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
|
||||
# intuition: If the user provides a generation config, we should replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
if self.use_spec_dec:
|
||||
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||
while self.request_handler.check_unfinished_seqs():
|
||||
output_seqs_list += self.steps_spec_dec()
|
||||
else:
|
||||
while self.request_handler.check_unfinished_seqs():
|
||||
output_seqs_list += self.step()
|
||||
|
||||
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
||||
|
||||
for seq in output_seqs_list:
|
||||
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
|
||||
|
||||
if return_token_ids:
|
||||
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
|
||||
return output_str, output_tokens_list
|
||||
else:
|
||||
return output_str
|
||||
|
||||
@property
|
||||
def has_prompt_template(self) -> bool:
|
||||
""" """
|
||||
return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
|
||||
"""
|
||||
This method will format the input prompt according to the prompt template given to the InferenceConfig.
|
||||
"""
|
||||
assert (
|
||||
self.has_prompt_template
|
||||
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
|
||||
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.prompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
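A hedged sketch of the unified facade after this refactor (the model id and config values are placeholders, and a launched distributed context is assumed): the same `InferenceEngine` call routes to `LLMEngine` or `DiffusionEngine` depending on the input type.

```python
from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model_path = "meta-llama/Llama-2-7b-hf"  # placeholder LLM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
engine = InferenceEngine(model_path, tokenizer=tokenizer, inference_config=InferenceConfig(max_output_len=64))
print(engine.generate(prompts="Introduce Colossal-AI briefly.")[0])
```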
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -630,168 +98,36 @@ class InferenceEngine:
|
|||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
kwargs: for LLM, it could be max_length, max_new_tokens, etc
|
||||
for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers
|
||||
"""
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
def step(self):
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
return self.engine.step()
|
||||
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
|
||||
prompts_token_ids = prompts_token_ids.tolist()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
|
||||
)
|
||||
|
||||
assert (
|
||||
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
|
||||
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
|
||||
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
request_id = next(self.counter)
if prompts is None:
prompt = None
else:
prompt = prompts[i]
|
||||
|
||||
max_length = kwargs.get("max_length", None)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", None)
|
||||
if max_length is None and max_new_tokens is None:
|
||||
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
|
||||
elif max_length is not None:
|
||||
max_new_tokens = max_length - len(prompts_token_ids[i])
|
||||
|
||||
if not self.inference_config.enable_streamingllm:
|
||||
assert (
|
||||
self.inference_config.max_output_len >= max_new_tokens
|
||||
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
|
||||
|
||||
sequence = Sequence(
|
||||
request_id,
|
||||
prompt,
|
||||
prompts_token_ids[i],
|
||||
block_size,
|
||||
None,
|
||||
self.tokenizer.eos_token_id,
|
||||
self.tokenizer.pad_token_id,
|
||||
max_output_len=max_new_tokens,
|
||||
ignore_eos=self.inference_config.ignore_eos,
|
||||
)
|
||||
self.request_handler.add_sequence(sequence)
|
||||
|
||||
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
||||
input_ids = batch.get_1D_inputs()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
if batch.is_prompts:
|
||||
n_tokens = sequence_lengths.sum().item()
|
||||
else:
|
||||
n_tokens = batch.current_batch_size
|
||||
if batch.use_spec_dec:
|
||||
n_tokens = batch.num_tokens_to_verify + 1
|
||||
assert n_tokens == input_ids.size(0)
|
||||
n_tokens = n_tokens * batch.current_batch_size
|
||||
output_tensor = torch.zeros(
|
||||
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||
)
|
||||
|
||||
batch_token_ids = None
|
||||
if (
|
||||
self.generation_config.repetition_penalty != 1.0
|
||||
or self.generation_config.no_repeat_ngram_size > 0
|
||||
or self.generation_config.forced_eos_token_id is not None
|
||||
):
|
||||
batch_token_ids = batch.batch_token_ids
|
||||
|
||||
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph = False
|
||||
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
|
||||
use_cuda_graph = True
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=batch.get_block_table_tensor(),
|
||||
sequence_lengths=sequence_lengths,
|
||||
fd_inter_tensor=batch.fd_inter_tensor,
|
||||
batch_size=batch.current_batch_size,
|
||||
is_prompts=batch.is_prompts,
|
||||
use_cuda_kernel=self.inference_config.use_cuda_kernel,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
high_precision=self.high_precision,
|
||||
kv_seq_len=sequence_lengths.max().item(),
|
||||
head_dim=batch.head_dim,
|
||||
dtype=batch.dtype,
|
||||
use_spec_dec=batch.use_spec_dec,
|
||||
num_tokens_to_verify=batch.num_tokens_to_verify,
|
||||
batch_token_ids=batch_token_ids,
|
||||
)
|
||||
|
||||
return input_ids, output_tensor, input_meta_data
|
||||
|
||||
def step(self) -> List[str]:
|
||||
"""
|
||||
In each step, do the following:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Get the input, input metadata, and output placeholder from the BatchBucket.
|
||||
3. Run model to generate the next token
|
||||
4. Update waiting list and running list in RequestHandler and get finished sequences.
|
||||
5. Decode and return finished sequences.
|
||||
|
||||
Returns:
|
||||
List[str]: Decoded finished sequences generated by one step.
|
||||
"""
|
||||
|
||||
batch = self.request_handler.schedule()
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.inference_config.enable_streamingllm:
|
||||
updated_block_ids = batch.streamingllm_update_batch(
|
||||
self.inference_config.start_token_size, self.inference_config.generated_token_size
|
||||
)
|
||||
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
|
||||
|
||||
next_tokens = search_tokens(
|
||||
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
|
||||
)
|
||||
self.request_handler.append_next_tokens(next_tokens)
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
return finished_sequences
|
||||
def __getattr__(self, name):
"""
The design logic of __getattr__ and __setattr__:
1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we want to invoke any member of DiffusionEngine/LLMEngine as if it were a member of InferenceEngine.
2. When the __init__ of InferenceEngine runs, we don't want to set attributes via self.__dict__["xxx"] = xxx; we want to use the usual self.xxx = xxx.
So we set the attribute `_initialized`. After initialization, if a member cannot be found on InferenceEngine, we try to get it from self.engine (DiffusionEngine/LLMEngine).
"""
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
return self.__dict__[name]
else:
return getattr(self.engine, name)
else:
return self.__dict__[name]

def __setattr__(self, name, value):
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
self.__dict__[name] = value
else:
setattr(self.engine, name, value)
else:
self.__dict__[name] = value
|
||||
|
|
|
@@ -0,0 +1,758 @@
|
|||
import time
|
||||
from itertools import count
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed as dist
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
from colossalai.inference.spec import Drafter, GlideInput
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.inference.utils import get_model_size, has_index_file
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .base_engine import BaseEngine
|
||||
from .request_handler import RequestHandler
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = {
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"BaichuanForCausalLM": AutoModelForCausalLM,
|
||||
}
|
||||
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||
|
||||
|
||||
class LLMEngine(BaseEngine):
|
||||
"""
|
||||
LLMEngine, which manages the LLM inference process.
|
||||
|
||||
Args:
|
||||
model_or_path (nn.Module or str): The model or the path to its checkpoint.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
inference_config (InferenceConfig, optional): The configuration information related to inference.
verbose (bool): Whether to log the generation process.
model_policy (Policy, optional): The policy used to shard the model with ShardFormer. Determined by the model type if not provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Union[Policy, type[Policy]] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
self.high_precision = inference_config.high_precision
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.request_handler = RequestHandler(self.inference_config, self.model_config)
|
||||
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
|
||||
# DISCUSS maybe move this into batch info?
|
||||
|
||||
self.counter = count()
|
||||
|
||||
self.use_cuda_graph = self.inference_config.use_cuda_graph
|
||||
if self.use_cuda_graph:
|
||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||
self.graph_memory_pool = None # Set during graph capture.
|
||||
if verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture on")
|
||||
|
||||
self.capture_model(self.k_cache, self.v_cache)
|
||||
|
||||
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||
self.use_spec_dec = self.inference_config.use_spec_dec
|
||||
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard the model and/or load its weights.
|
||||
|
||||
Args:
|
||||
model_or_path (Union[nn.Module, str]): The model or the path to its checkpoint in transformers format.
model_policy (Policy): The policy used to shard the model.
model_shard_infer_config (ModelShardInferenceConfig): The configuration used to initialize modules for inference.
|
||||
"""
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
if arch in _supported_models.keys():
|
||||
if arch == "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _supported_models[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
)
|
||||
else:
|
||||
model = model_or_path
|
||||
|
||||
self.model_config = model.config
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
model = model.to(self.dtype).eval()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
|
||||
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
|
||||
model_policy = model_policy_map.get(model_policy_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = ModelWrapper(model).to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
|
||||
assert self.use_cuda_graph, "Please enable CUDA graph in the inference config (use_cuda_graph=True)."
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture begin")
|
||||
|
||||
t_capture_begin = time.perf_counter()
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
|
||||
|
||||
# Prepare dummy inputs. These will be reused for all batch sizes.
|
||||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
||||
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
|
||||
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
|
||||
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
|
||||
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
|
||||
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
|
||||
self.graph_block_tables[0, :] = np.arange(
|
||||
0, max_num_blocks
|
||||
)  # NOTE this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
|
||||
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
||||
output_tensor = torch.zeros(
|
||||
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
|
||||
)
|
||||
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
|
||||
|
||||
max_num_seqs = self.inference_config.max_batch_size
|
||||
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
|
||||
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
|
||||
# NOTE this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
|
||||
sequence_lengths[0] = torch.tensor(
|
||||
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
|
||||
).cuda()
|
||||
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
if self.verbose:
|
||||
self.logger.info(f"batch size {batch_size} graph capturing")
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=block_tables[:batch_size],
|
||||
sequence_lengths=sequence_lengths[:batch_size],
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
batch_size=batch_size,
|
||||
is_prompts=False,
|
||||
use_cuda_graph=True,
|
||||
high_precision=False,
|
||||
kv_seq_len=sequence_lengths[:batch_size].max().item(),
|
||||
head_dim=head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens_ids[:batch_size],
|
||||
output_tensor[:batch_size],
|
||||
input_meta_data,
|
||||
k_caches=k_cache,
|
||||
v_caches=v_cache,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
|
||||
t_capture_end = time.perf_counter()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
"""Verify the input args"""
|
||||
if not isinstance(self.inference_config, InferenceConfig):
|
||||
raise TypeError("Invalid type of inference config provided.")
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||
raise TypeError(
|
||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||
)
|
||||
if isinstance(self.model, ModelWrapper):
|
||||
model = self.model.module
|
||||
assert (
|
||||
model.__class__.__name__ in _supported_models.keys()
|
||||
), f"Model {self.model.__class__.__name__} is not supported."
|
||||
|
||||
def enable_spec_dec(
|
||||
self,
|
||||
drafter_model: nn.Module = None,
|
||||
n_spec_tokens: int = None,
|
||||
use_glide_drafter: bool = False,
|
||||
) -> None:
|
||||
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||
|
||||
Args:
|
||||
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||
If provided, the previous drafter and drafter model, if they exist, will be overwritten.
|
||||
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||
If True, the drafter model will be replaced by a glide model.
|
||||
|
||||
```python
|
||||
...
|
||||
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
engine.generate(...) # Speculative Decoding
|
||||
|
||||
engine.disable_spec_dec()
|
||||
engine.generate(...) # Normal generation
|
||||
|
||||
engine.enable_spec_dec()
|
||||
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
if n_spec_tokens is not None:
|
||||
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||
self.n_spec_tokens = n_spec_tokens
|
||||
if drafter_model is not None:
|
||||
assert isinstance(drafter_model, nn.Module)
|
||||
# overwrite the drafter, if exists
|
||||
self.clear_spec_dec()
|
||||
self.drafter_model = drafter_model
|
||||
self.drafter = Drafter(
|
||||
self.drafter_model,
|
||||
self.tokenizer,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# check if the provided drafter model is compatible with GLIDE structure
|
||||
# when `use_glide_drafter` is set to True
|
||||
if (
|
||||
use_glide_drafter
|
||||
and hasattr(drafter_model, "model")
|
||||
and hasattr(drafter_model.model, "layers")
|
||||
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||
):
|
||||
self.use_glide = use_glide_drafter
|
||||
elif use_glide_drafter:
|
||||
self.logger.warning(
|
||||
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||
)
|
||||
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||
# using speculative decoding for subsequent generations
|
||||
self.use_spec_dec = True
|
||||
|
||||
def disable_spec_dec(self) -> None:
|
||||
"""Disable using speculative decoding for subsequent generations."""
|
||||
self.request_handler.unset_spec_dec_mode()
|
||||
# set back to the maximum number of tokens to speculate
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def clear_spec_dec(self) -> None:
|
||||
"""Clear relatable structures of speculative decoding, if exist."""
|
||||
if self.use_spec_dec:
|
||||
self.disable_spec_dec()
|
||||
if self.drafter_model or self.drafter:
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
torch.cuda.empty_cache()
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def steps_spec_dec(self) -> List[Sequence]:
|
||||
"""
|
||||
Run Speculative Decoding steps. This is like retrieving a single batch and launching inference
with many rounds of speculation by a drafter model and verification by the main model.
|
||||
|
||||
Returns:
|
||||
List[Sequence]: finished sequences generated by one step.
|
||||
"""
|
||||
batch = self.request_handler.schedule() # prefill batch
|
||||
assert batch.current_batch_size == 1, "Only batch size 1 is supported for speculative decoding for now."
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
|
||||
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
# append new inputs to the batch, temporarily
|
||||
batch.append_batch_tokens(next_tokens)
|
||||
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(1)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
while True:
|
||||
# HACK Retrieve the running batch
|
||||
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||
batch = self.request_handler.running_bb # running batch
|
||||
assert batch.current_batch_size == 1, "Only batch size 1 is supported for speculative decoding for now."
|
||||
|
||||
# 3. Decoding - Drafter model speculates `n` tokens
|
||||
glide_input = None
|
||||
if self.use_glide:
|
||||
glide_input = GlideInput(
|
||||
batch.get_block_table_tensor(),
|
||||
self.k_cache[-1],  # use kv caches of the last layer
|
||||
self.v_cache[-1],
|
||||
batch.get_sequence_lengths(),
|
||||
n_spec_tokens=self.n_spec_tokens,
|
||||
)
|
||||
|
||||
drafter_out = self.drafter.speculate(
|
||||
input_token_ids,
|
||||
self.n_spec_tokens,
|
||||
drafter_past_key_values,
|
||||
glide_input=glide_input,
|
||||
)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
drafter_spec_length = drafter_out.speculated_length
|
||||
|
||||
for next_token_id_spec in next_token_ids_spec:
|
||||
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||
cur_length = batch.seq_lengths[0].item()
|
||||
if already_allocated_kv_len < cur_length:
|
||||
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||
already_allocated_kv_len = cur_length
|
||||
|
||||
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
|
||||
# 5. Compare and process the results
|
||||
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||
|
||||
# revoke appended tokens for each Sequence in the current batch
|
||||
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||
|
||||
# append the last correct token generated by the main model
|
||||
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||
|
||||
# trim past key values of the drafter model
|
||||
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||
)
|
||||
|
||||
# prepare inputs for the next round of speculation
|
||||
n = 1 if n_matches < drafter_spec_length else 2
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(n)
|
||||
|
||||
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||
finished_sequences = self.request_handler.update()
|
||||
if len(finished_sequences) > 0:
|
||||
break
|
||||
|
||||
# Reset the number of speculated tokens for the batch.
# This handles the last round of speculation, in which the number of tokens speculated
# by the drafter may be less than the number of speculated tokens set on the engine.
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||
|
||||
return finished_sequences
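The accept/revoke arithmetic used in step 5 above can be illustrated in isolation. A small self-contained sketch (the tensor values and helper name are made up for illustration):

```python
import torch

def count_accepted(drafted: torch.Tensor, verified: torch.Tensor) -> int:
    """Number of leading drafted tokens the main model agrees with.

    `verified` holds len(drafted) + 1 tokens: one verifier output per drafted
    position plus the bonus token produced after the last drafted one.
    """
    mismatch = torch.nonzero(verified[:-1] != drafted)
    return drafted.numel() if mismatch.numel() == 0 else int(mismatch[0, 0])

drafted = torch.tensor([11, 22, 33, 44, 55])
verified = torch.tensor([11, 22, 37, 44, 55, 66])

n_matches = count_accepted(drafted, verified)  # 2: only positions 0 and 1 agree
n_revoked = drafted.numel() - n_matches        # 3 drafted tokens are rolled back
bonus = int(verified[n_matches])               # 37, appended as the correct token
print(n_matches, n_revoked, bonus)             # 2 3 37
```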
|
||||
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Executing the inference step.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
|
||||
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
|
||||
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
prompts_token_ids=prompts_token_ids,
|
||||
**gen_config_dict,
|
||||
)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
|
||||
# intuition: If the user provides a generation config, we should replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
if self.use_spec_dec:
|
||||
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_seqs_list += self.steps_spec_dec()
|
||||
else:
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_seqs_list += self.step()
|
||||
|
||||
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
||||
|
||||
for seq in output_seqs_list:
|
||||
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
|
||||
|
||||
if return_token_ids:
|
||||
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
|
||||
return output_str, output_tokens_list
|
||||
else:
|
||||
return output_str
|
||||
|
||||
@property
|
||||
def has_prompt_template(self) -> bool:
|
||||
""" """
|
||||
return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
|
||||
"""
|
||||
This method will format the input prompt according to the prompt template given to the InferenceConfig.
|
||||
"""
|
||||
assert (
|
||||
self.has_prompt_template
|
||||
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
|
||||
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.prompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Add requests.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
|
||||
prompts_token_ids = prompts_token_ids.tolist()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
|
||||
)
|
||||
|
||||
assert (
|
||||
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
|
||||
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
|
||||
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts is None:
|
||||
prompt = None
|
||||
else:
|
||||
prompt = prompts[i]
|
||||
|
||||
max_length = kwargs.get("max_length", None)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", None)
|
||||
if max_length is None and max_new_tokens is None:
|
||||
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
|
||||
elif max_length is not None:
|
||||
max_new_tokens = max_length - len(prompts_token_ids[i])
|
||||
|
||||
if not self.inference_config.enable_streamingllm:
|
||||
assert (
|
||||
self.inference_config.max_output_len >= max_new_tokens
|
||||
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
|
||||
|
||||
sequence = Sequence(
|
||||
request_id,
|
||||
prompt,
|
||||
prompts_token_ids[i],
|
||||
block_size,
|
||||
None,
|
||||
self.tokenizer.eos_token_id,
|
||||
self.tokenizer.pad_token_id,
|
||||
max_output_len=max_new_tokens,
|
||||
ignore_eos=self.inference_config.ignore_eos,
|
||||
)
|
||||
self.request_handler.add_sequence(sequence)
|
||||
|
||||
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
||||
input_ids = batch.get_1D_inputs()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
if batch.is_prompts:
|
||||
n_tokens = sequence_lengths.sum().item()
|
||||
else:
|
||||
n_tokens = batch.current_batch_size
|
||||
if batch.use_spec_dec:
|
||||
n_tokens = batch.num_tokens_to_verify + 1
|
||||
assert n_tokens == input_ids.size(0)
|
||||
n_tokens = n_tokens * batch.current_batch_size
|
||||
output_tensor = torch.zeros(
|
||||
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||
)
|
||||
|
||||
batch_token_ids = None
|
||||
if (
|
||||
self.generation_config.repetition_penalty != 1.0
|
||||
or self.generation_config.no_repeat_ngram_size > 0
|
||||
or self.generation_config.forced_eos_token_id is not None
|
||||
):
|
||||
batch_token_ids = batch.batch_token_ids
|
||||
|
||||
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph = False
|
||||
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
|
||||
use_cuda_graph = True
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=batch.get_block_table_tensor(),
|
||||
sequence_lengths=sequence_lengths,
|
||||
fd_inter_tensor=batch.fd_inter_tensor,
|
||||
batch_size=batch.current_batch_size,
|
||||
is_prompts=batch.is_prompts,
|
||||
use_cuda_kernel=self.inference_config.use_cuda_kernel,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
high_precision=self.high_precision,
|
||||
kv_seq_len=sequence_lengths.max().item(),
|
||||
head_dim=batch.head_dim,
|
||||
dtype=batch.dtype,
|
||||
use_spec_dec=batch.use_spec_dec,
|
||||
num_tokens_to_verify=batch.num_tokens_to_verify,
|
||||
batch_token_ids=batch_token_ids,
|
||||
)
|
||||
|
||||
return input_ids, output_tensor, input_meta_data
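The n_tokens bookkeeping above sizes the pre-allocated output buffer. A standalone illustration under assumed shapes (the numbers are made up):

```python
import torch

sequence_lengths = torch.tensor([5, 3, 7])  # three prompts in a prefill batch
batch_size, num_heads, head_dim = 3, 8, 64

# Prefill: one hidden state per prompt token across the whole batch.
n_tokens_prefill = int(sequence_lengths.sum())  # 15

# Decode: one new token per running sequence.
n_tokens_decode = batch_size  # 3

# Speculative decoding (currently batch size 1): spec tokens plus one bonus token.
num_tokens_to_verify = 4
n_tokens_spec = (num_tokens_to_verify + 1) * 1  # 5

output_prefill = torch.zeros((n_tokens_prefill, num_heads * head_dim))
output_decode = torch.zeros((n_tokens_decode, num_heads * head_dim))
```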
|
||||
|
||||
def step(self) -> List[str]:
|
||||
"""
|
||||
In each step, do the following:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Get the input, input metadata, and output placeholder from the BatchBucket.
|
||||
3. Run model to generate the next token
|
||||
4. Update waiting list and running list in RequestHandler and get finished sequences.
|
||||
5. Decode and return finished sequences.
|
||||
|
||||
Returns:
|
||||
List[str]: Decoded finished sequences generated by one step.
|
||||
"""
|
||||
|
||||
batch = self.request_handler.schedule()
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.inference_config.enable_streamingllm:
|
||||
updated_block_ids = batch.streamingllm_update_batch(
|
||||
self.inference_config.start_token_size, self.inference_config.generated_token_size
|
||||
)
|
||||
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
|
||||
|
||||
next_tokens = search_tokens(
|
||||
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
|
||||
)
|
||||
self.request_handler.append_next_tokens(next_tokens)
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
return finished_sequences
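An end-to-end usage sketch for the engine defined above. The model name, config fields, and import path are assumptions for illustration and should be checked against the current colossalai.inference package; distributed initialization (e.g. colossalai.launch) is assumed to have happened already:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # wrapper that dispatches to LLMEngine

model_path = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Field names mirror those referenced in this file (max_input_len, max_output_len, ...).
inference_config = InferenceConfig(max_input_len=256, max_output_len=128)
engine = InferenceEngine(model, tokenizer, inference_config)

# generate() adds the requests, runs step()/steps_spec_dec() until all
# sequences finish, and returns the decoded strings.
outputs = engine.generate(prompts=["Introduce ColossalAI briefly."])
print(outputs[0])
```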
|
|
@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
|
|||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
|
||||
from colossalai.inference.struct import RequestStatus, Sequence
|
||||
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
@@ -98,7 +98,46 @@ class RunningList:
|
|||
self._decoding[seq_id] = self._prefill.pop(seq_id)
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
class NaiveRequestHandler:
|
||||
def __init__(self) -> None:
|
||||
self.running_list: List[DiffusionSequence] = []
|
||||
self.waiting_list: List[DiffusionSequence] = []
|
||||
|
||||
def _has_waiting(self) -> bool:
|
||||
return any(lst for lst in self.waiting_list)
|
||||
|
||||
def _has_running(self) -> bool:
|
||||
return any(lst for lst in self.running_list)
|
||||
|
||||
def check_unfinished_reqs(self):
|
||||
return self._has_waiting() or self._has_running()
|
||||
|
||||
def add_sequence(self, seq: DiffusionSequence):
|
||||
"""
|
||||
Add the request to waiting list.
|
||||
"""
|
||||
assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
|
||||
self.waiting_list.append(seq)
|
||||
|
||||
def _find_sequence(self, request_id: int) -> DiffusionSequence:
|
||||
"""
|
||||
Find the request by request_id.
|
||||
"""
|
||||
for seq in self.waiting_list + self.running_list:
if seq.request_id == request_id:
return seq
|
||||
return None
|
||||
|
||||
def schedule(self):
|
||||
ret = None
|
||||
if self._has_waiting():
|
||||
ret = self.waiting_list[0]
|
||||
self.waiting_list = self.waiting_list[1:]
|
||||
return ret
|
||||
|
||||
|
||||
class RequestHandler(NaiveRequestHandler):
|
||||
"""
|
||||
RequestHandler is the core for handling existing requests and updating current batch.
|
||||
During generation process, we call schedule function each iteration to update current batch.
|
||||
|
@@ -176,12 +215,12 @@ class RequestHandler:
|
|||
generated_token_size=inference_config.generated_token_size,
|
||||
)
|
||||
|
||||
def _has_running(self) -> bool:
|
||||
return not self.running_bb.is_empty()
|
||||
|
||||
def _init_cache(self, model_config):
|
||||
self.cache_manager = KVCacheManager(self.inference_config, model_config)
|
||||
|
||||
def _has_waiting(self) -> bool:
|
||||
return any(lst for lst in self.waiting_list)
|
||||
|
||||
def get_kvcache(self):
|
||||
return self.cache_manager.get_kv_cache()
|
||||
|
||||
|
@@ -318,7 +357,7 @@ class RequestHandler:
|
|||
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
|
||||
seq.mark_finished()
|
||||
|
||||
def check_unfinished_seqs(self) -> bool:
|
||||
def check_unfinished_reqs(self) -> bool:
|
||||
return self._has_waiting() or not self.running_list.is_empty()
|
||||
|
||||
def total_requests_in_batch_bucket(self) -> int:
|
||||
|
|
|
@@ -0,0 +1,54 @@
|
|||
import inspect
|
||||
import types
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DiffusionPipe(nn.Module):
|
||||
"""
|
||||
This class converts a `DiffusionPipeline` into an `nn.Module` and preserves most of the original attributes, functions, and properties.
|
||||
"""
|
||||
|
||||
def __init__(self, source_obj) -> None:
|
||||
super(DiffusionPipe, self).__init__()
|
||||
|
||||
for k, v in source_obj.__dict__.items():
|
||||
if isinstance(v, nn.Module):
|
||||
self.add_module(k, v)
|
||||
else:
|
||||
setattr(self, k, v)
|
||||
|
||||
skip_list = ["_execution_device", "to", "device"] # this
|
||||
|
||||
for name, member in inspect.getmembers(source_obj.__class__):
|
||||
if name in skip_list:
|
||||
continue
|
||||
if not name.startswith("__") and not name.endswith("__"):
|
||||
if isinstance(member, property):
|
||||
setattr(self.__class__, name, member)
|
||||
elif inspect.isfunction(member) or inspect.ismethod(member):
|
||||
bound_method = types.MethodType(member, self)
|
||||
setattr(self, name, bound_method)
|
||||
elif not callable(member) and not isinstance(member, property):
|
||||
setattr(self, name, member)
|
||||
elif name == "__call__":
|
||||
bound_method = types.MethodType(member, self)
|
||||
setattr(self, "_forward", bound_method)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
|
||||
Accelerate's module hooks.
|
||||
"""
|
||||
# return self.device
|
||||
return torch.device("cuda")
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._forward(*args, **kwargs)
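A hedged sketch of how such a wrapper might be used; the diffusers pipeline class and checkpoint name are assumptions for illustration, not part of this file:

```python
import torch
from diffusers import PixArtAlphaPipeline  # assumed example pipeline

# Wrap a diffusers pipeline so downstream code can treat it as an nn.Module;
# its nn.Module members (transformer, vae, text encoder, ...) become registered children.
pipeline = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe_module = DiffusionPipe(pipeline)

# The original __call__ is re-bound as _forward, so forward() runs the pipeline
# with whatever keyword arguments the wrapped pipeline accepts.
output = pipe_module(prompt="an astronaut riding a horse")
```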
|
|
@@ -0,0 +1,626 @@
|
|||
# Code refer and adapted from:
|
||||
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
|
||||
# https://github.com/PipeFusion/PipeFusion
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models import attention_processor
|
||||
from diffusers.models.attention import Attention
|
||||
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
|
||||
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from torch import nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
HAS_FLASH_ATTN = False
|
||||
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
|
||||
def PixArtAlphaTransformer2DModel_forward(
|
||||
self: PixArtTransformer2DModel,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
assert hasattr(
|
||||
self, "patched_parallel_size"
|
||||
), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
|
||||
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size = hidden_states.shape[0]
|
||||
height, width = (
|
||||
hidden_states.shape[-2] // self.config.patch_size,
|
||||
hidden_states.shape[-1] // self.config.patch_size,
|
||||
)
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
if self.caption_projection is not None:
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
|
||||
2, dim=1
|
||||
)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
height // self.patched_parallel_size,
|
||||
width,
|
||||
self.config.patch_size,
|
||||
self.config.patch_size,
|
||||
self.out_channels,
|
||||
)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
self.out_channels,
|
||||
height // self.patched_parallel_size * self.config.patch_size,
|
||||
width * self.config.patch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# enable Distrifusion Optimization
|
||||
if hasattr(self, "patched_parallel_size"):
|
||||
from torch import distributed as dist
|
||||
|
||||
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
|
||||
self.output_buffer = torch.empty_like(output)
|
||||
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
|
||||
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
|
||||
output = output.contiguous()
|
||||
dist.all_gather(self.buffer_list, output, async_op=False)
|
||||
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
|
||||
output = self.output_buffer
|
||||
|
||||
return (output,)
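The Distrifusion gather at the end reduces to a few lines; a standalone sketch (the function name is illustrative, and an initialized process group is assumed):

```python
import torch
import torch.distributed as dist

def gather_height_patches(local_out: torch.Tensor, patched_parallel_size: int) -> torch.Tensor:
    """All-gather [N, C, H/p, W] slices from p ranks and concatenate them back to [N, C, H, W]."""
    buffer_list = [torch.empty_like(local_out) for _ in range(patched_parallel_size)]
    dist.all_gather(buffer_list, local_out.contiguous(), async_op=False)
    return torch.cat(buffer_list, dim=2)  # dim=2 is the (patched) height dimension
```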
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
|
||||
def SD3Transformer2DModel_forward(
|
||||
self: SD3Transformer2DModel,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
pooled_projections: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor]:
|
||||
|
||||
assert hasattr(
|
||||
self, "patched_parallel_size"
|
||||
), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
patch_size = self.config.patch_size
|
||||
height = height // patch_size // self.patched_parallel_size
|
||||
width = width // patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
# enable Distrifusion Optimization
|
||||
if hasattr(self, "patched_parallel_size"):
|
||||
from torch import distributed as dist
|
||||
|
||||
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
|
||||
self.output_buffer = torch.empty_like(output)
|
||||
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
|
||||
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
|
||||
output = output.contiguous()
|
||||
dist.all_gather(self.buffer_list, output, async_op=False)
|
||||
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
|
||||
output = self.output_buffer
|
||||
|
||||
return (output,)
|
||||
|
||||
|
||||
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
class DistrifusionPatchEmbed(ParallelModule):
    def __init__(
        self,
        module: PatchEmbed,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.module = module
        self.rank = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size

    @staticmethod
    def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        distrifusion_embed = DistrifusionPatchEmbed(
            module, process_group, model_shard_infer_config=model_shard_infer_config
        )
        return distrifusion_embed

    def forward(self, latent):
        module = self.module
        if module.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size

        latent = module.proj(latent)
        if module.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if module.layer_norm:
            latent = module.norm(latent)
        if module.pos_embed is None:
            return latent.to(latent.dtype)
        # Interpolate or crop positional embeddings as needed
        if module.pos_embed_max_size:
            pos_embed = module.cropped_pos_embed(height, width)
        else:
            if module.height != height or module.width != width:
                pos_embed = get_2d_sincos_pos_embed(
                    embed_dim=module.pos_embed.shape[-1],
                    grid_size=(height, width),
                    base_size=module.base_size,
                    interpolation_scale=module.interpolation_scale,
                )
                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
            else:
                pos_embed = module.pos_embed

        b, c, h = pos_embed.shape
        pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]

        return (latent + pos_embed).to(latent.dtype)

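# Illustrative sketch (not part of the patch): the view/index trick above keeps only the
# positional-embedding rows that belong to this rank's patch slice. All sizes are invented.
import torch

patched_parallelism_size = 2
rank = 1
pos_embed = torch.arange(1 * 8 * 4, dtype=torch.float32).reshape(1, 8, 4)  # (b=1, seq=8, dim=4)

b, c, h = pos_embed.shape
local = pos_embed.view(b, patched_parallelism_size, -1, h)[:, rank]  # (1, 4, 4): rows 4..7
assert torch.equal(local, pos_embed[:, 4:8])
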
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
class DistrifusionConv2D(ParallelModule):

    def __init__(
        self,
        module: nn.Conv2d,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.module = module
        self.rank = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size

    @staticmethod
    def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
        return distrifusion_conv

    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape

        stride = self.module.stride[0]
        padding = self.module.padding[0]

        output_h = x.shape[2] // stride // self.patched_parallelism_size
        idx = dist.get_rank()
        h_begin = output_h * idx * stride - padding
        h_end = output_h * (idx + 1) * stride + padding
        final_padding = [padding, padding, 0, 0]
        if h_begin < 0:
            h_begin = 0
            final_padding[2] = padding
        if h_end > h:
            h_end = h
            final_padding[3] = padding
        sliced_input = x[:, :, h_begin:h_end, :]
        padded_input = F.pad(sliced_input, final_padding, mode="constant")
        return F.conv2d(
            padded_input,
            self.module.weight,
            self.module.bias,
            stride=stride,
            padding="valid",
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = self.sliced_forward(input)
        return output

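# Illustrative single-process sketch (not part of the patch): each "rank" takes its band of
# input rows plus a halo row from the neighbour, so the sliced convolutions match a full
# convolution. The kernel, sizes, stride and padding below are invented for the example.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)
stride, padding, parallel_size = 1, 1, 2

full = F.conv2d(F.pad(x, [padding, padding, padding, padding]), weight, stride=stride)

rows = []
output_h = x.shape[2] // stride // parallel_size
for idx in range(parallel_size):
    h_begin = max(output_h * idx * stride - padding, 0)
    h_end = min(output_h * (idx + 1) * stride + padding, x.shape[2])
    final_padding = [padding, padding,
                     padding if output_h * idx * stride - padding < 0 else 0,
                     padding if output_h * (idx + 1) * stride + padding > x.shape[2] else 0]
    sliced = F.pad(x[:, :, h_begin:h_end, :], final_padding)
    rows.append(F.conv2d(sliced, weight, stride=stride))

assert torch.allclose(torch.cat(rows, dim=2), full, atol=1e-5)
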
# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
class DistrifusionFusedAttention(ParallelModule):

    def __init__(
        self,
        module: attention_processor.Attention,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.counter = 0
        self.module = module
        self.buffer_list = None
        self.kv_buffer_idx = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
        self.handle = None
        self.process_group = process_group
        self.warm_step = 5  # for warmup

    @staticmethod
    def from_native_module(
        module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        return DistrifusionFusedAttention(
            module=module,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
        )

    def _forward(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        kv = torch.cat([key, value], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)

        if self.patched_parallelism_size == 1:
            full_kv = kv
        else:
            if self.buffer_list is None:  # buffer not created
                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
            elif self.counter <= self.warm_step:
                # logger.info(f"warmup: {self.counter}")
                dist.all_gather(
                    self.buffer_list,
                    kv,
                    group=self.process_group,
                    async_op=False,
                )
                full_kv = torch.cat(self.buffer_list, dim=1)
            else:
                # logger.info(f"use old kv to infer: {self.counter}")
                self.buffer_list[self.kv_buffer_idx].copy_(kv)
                full_kv = torch.cat(self.buffer_list, dim=1)
                assert self.handle is None, "we should maintain the kv of last step"
                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)

        key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )  # NOTE(@lry89757) for torch >= 2.2, flash attn has been already integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # wait for the async KV all-gather launched in the previous step before touching the buffer
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

        b, l, c = hidden_states.shape
        kv_shape = (b, l, self.module.to_k.out_features * 2)
        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
            self.buffer_list = [
                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
                for _ in range(self.patched_parallelism_size)
            ]
            self.counter = 0

        attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}

        output = self._forward(
            self.module,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        self.counter += 1

        return output

# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
class DistriSelfAttention(ParallelModule):
    def __init__(
        self,
        module: Attention,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.counter = 0
        self.module = module
        self.buffer_list = None
        self.kv_buffer_idx = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
        self.handle = None
        self.process_group = process_group
        self.warm_step = 3  # for warmup

    @staticmethod
    def from_native_module(
        module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        return DistriSelfAttention(
            module=module,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
        )

    def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
        attn = self.module
        assert isinstance(attn, Attention)

        residual = hidden_states

        batch_size, sequence_length, _ = hidden_states.shape

        query = attn.to_q(hidden_states)

        encoder_hidden_states = hidden_states
        k = self.module.to_k(encoder_hidden_states)
        v = self.module.to_v(encoder_hidden_states)
        kv = torch.cat([k, v], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)

        if self.patched_parallelism_size == 1:
            full_kv = kv
        else:
            if self.buffer_list is None:  # buffer not created
                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
            elif self.counter <= self.warm_step:
                # logger.info(f"warmup: {self.counter}")
                dist.all_gather(
                    self.buffer_list,
                    kv,
                    group=self.process_group,
                    async_op=False,
                )
                full_kv = torch.cat(self.buffer_list, dim=1)
            else:
                # logger.info(f"use old kv to infer: {self.counter}")
                self.buffer_list[self.kv_buffer_idx].copy_(kv)
                full_kv = torch.cat(self.buffer_list, dim=1)
                assert self.handle is None, "we should maintain the kv of last step"
                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)

        if HAS_FLASH_ATTN:
            # flash attn
            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim)
            key = key.view(batch_size, -1, attn.heads, head_dim)
            value = value.view(batch_size, -1, attn.heads, head_dim)

            hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
        else:
            # naive attn
            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)

            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # wait for the async KV all-gather launched in the previous step before reusing the buffer
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

        b, l, c = hidden_states.shape
        kv_shape = (b, l, self.module.to_k.out_features * 2)
        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
            self.buffer_list = [
                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
                for _ in range(self.patched_parallelism_size)
            ]
            self.counter = 0

        output = self._forward(hidden_states, scale=scale)

        self.counter += 1
        return output

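# Illustrative single-process sketch (not part of the patch) of the stale-KV schedule shared
# by both attention wrappers above: the first call replicates the local KV, the next
# `warm_step` calls gather synchronously, and afterwards the rank only refreshes its own
# slot while the remote slots keep whatever the last completed gather left there (one step
# stale in the real, communication-overlapped version). Tensors and sizes are invented and
# `dist.all_gather` is replaced by plain list copies so the sketch runs on one process.
import torch

parallel_size, rank, warm_step = 2, 0, 3
buffer_list = None
for counter in range(6):
    kv = torch.full((1, 4, 8), float(counter))           # this rank's fresh KV slice
    if buffer_list is None:                               # first call: no buffer yet
        buffer_list = [torch.empty_like(kv) for _ in range(parallel_size)]
        full_kv = torch.cat([kv] * parallel_size, dim=1)
    elif counter <= warm_step:                            # warmup: stand-in for a sync all_gather
        for i in range(parallel_size):
            buffer_list[i].copy_(kv)
        full_kv = torch.cat(buffer_list, dim=1)
    else:                                                 # steady state: fresh local + stale remote
        buffer_list[rank].copy_(kv)
        full_kv = torch.cat(buffer_list, dim=1)
        # the real code launches dist.all_gather(..., async_op=True) here and waits next step
    print(counter, full_kv[0, 0, 0].item(), full_kv[0, 4, 0].item())  # local vs. remote slot value
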
@@ -0,0 +1,220 @@
# Code adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

from typing import Callable, List, Optional, Union

import PIL.Image
import torch
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
    ASPECT_RATIO_256_BIN,
    ASPECT_RATIO_512_BIN,
    ASPECT_RATIO_1024_BIN,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

from colossalai.logging import get_dist_logger

from ..layers.diffusion import DiffusionPipe

logger = get_dist_logger(__name__)


@torch.no_grad()
def pixart_alpha_forward(
    self: DiffusionPipe,
    prompt: Union[str, List[str]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 20,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 4.5,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    use_resolution_binning: bool = True,
    max_sequence_length: int = 120,
    **kwargs,
) -> PIL.Image.Image:
    # 1. Check inputs. Raise error if not correct
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor
    if use_resolution_binning:
        if self.transformer.config.sample_size == 128:
            aspect_ratio_bin = ASPECT_RATIO_1024_BIN
        elif self.transformer.config.sample_size == 64:
            aspect_ratio_bin = ASPECT_RATIO_512_BIN
        elif self.transformer.config.sample_size == 32:
            aspect_ratio_bin = ASPECT_RATIO_256_BIN
        else:
            raise ValueError("Invalid sample size")
        orig_height, orig_width = height, width
        height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_steps,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    )

    # 2. Default height and width to transformer
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,
    )
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Prepare micro-conditions.
    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    if self.transformer.config.sample_size == 128:
        resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
        resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

        if do_classifier_free_guidance:
            resolution = torch.cat([resolution, resolution], dim=0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(current_timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                timestep=current_timestep,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # learned sigma
            if self.transformer.config.out_channels // 2 == latent_channels:
                noise_pred = noise_pred.chunk(2, dim=1)[0]
            else:
                noise_pred = noise_pred

            # compute previous image: x_t -> x_t-1
            if num_inference_steps == 1:
                # For DMD one step sampling: https://arxiv.org/abs/2311.18828
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    output_type = "pil"  # TODO(@lry89757) temporarily image, please support more return output
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        if use_resolution_binning:
            image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
    else:
        image = latents

    if not output_type == "latent":
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    # self.maybe_free_model_hooks()

    return image

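# Toy illustration (not part of the patch) of the classifier-free guidance step used in the
# denoising loops above: the batched prediction is chunked into unconditional and text
# halves, then recombined with the guidance scale. All values are made up.
import torch

guidance_scale = 4.5
noise_pred = torch.stack([torch.zeros(4), torch.ones(4)])   # [uncond, text] stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert torch.allclose(guided, torch.full((1, 4), 4.5))
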
@@ -0,0 +1,178 @@
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps

from ..layers.diffusion import DiffusionPipe


# TODO(@lry89757) temporarily image, please support more return output
@torch.no_grad()
def sd3_forward(
    self: DiffusionPipe,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if output_type == "latent":
        image = latents
    else:
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    return image

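# Illustrative sketch (not part of the patch): before VAE decoding, SD3 latents are unscaled
# with both a scaling factor and a shift factor, unlike the PixArt path above, which only
# divides by the scaling factor. The numeric values below are placeholders; the real ones
# come from self.vae.config.
import torch

scaling_factor, shift_factor = 1.5, 0.06   # invented stand-ins for vae.config values
latents = torch.randn(1, 16, 64, 64)
decodable = latents / scaling_factor + shift_factor  # what gets passed to vae.decode(...)
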
@@ -1,16 +1,22 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
from .pixart_alpha import PixArtAlphaInferPolicy
from .stablediffusion3 import StableDiffusion3InferPolicy

model_policy_map = {
    "nopadding_llama": NoPaddingLlamaModelInferPolicy,
    "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
    "glide_llama": GlideLlamaModelPolicy,
    "StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
    "PixArtAlphaPipeline": PixArtAlphaInferPolicy,
}

__all__ = [
    "NoPaddingLlamaModelInferPolicy",
    "NoPaddingBaichuanModelInferPolicy",
    "GlideLlamaModelPolicy",
    "StableDiffusion3InferPolicy",
    "PixArtAlphaInferPolicy",
    "model_policy_map",
]

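# Hypothetical usage sketch (not part of the patch): keys in `model_policy_map` are either
# short model tags or diffusers pipeline class names, so a policy class can be resolved from
# a pipeline's class name. The name below is just a string chosen for illustration.
pipeline_cls_name = "StableDiffusion3Pipeline"
policy_cls = model_policy_map.get(pipeline_cls_name)  # -> StableDiffusion3InferPolicy, or None if unsupported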