diff --git a/.compatibility b/.compatibility index d90a74b58..4f808740b 100644 --- a/.compatibility +++ b/.compatibility @@ -1 +1,3 @@ 2.1.0-12.1.0 +2.2.2-12.1.0 +2.3.0-12.1.0 diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 3eee564c2..1a458d7bb 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -55,41 +55,27 @@ jobs: steps: - name: Install dependencies run: | - pip install -U pip setuptools==68.2.2 wheel --user - - uses: actions/checkout@v2 - with: - repository: hpcaitech/TensorNVMe - ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - path: TensorNVMe - - name: Install tensornvme - run: | - cd TensorNVMe apt update && apt install -y cmake - pip install -r requirements.txt - DISABLE_URING=1 pip install -v . + pip install -U pip setuptools==68.2.2 wheel --user + - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Download cub for CUDA 10.2 - run: | - CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') - # check if it is CUDA 10.2 - # download cub - if [ "$CUDA_VERSION" = "10.2" ]; then - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ - fi - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v . - pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt + + - name: Install tensornvme + run: | + DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git + - name: Unit Testing run: | PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index b418c843e..770f4b933 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -49,42 +49,27 @@ jobs: steps: - name: Install dependencies run: | - pip install -U pip setuptools==68.2.2 wheel --user - - uses: actions/checkout@v2 - with: - repository: hpcaitech/TensorNVMe - ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - path: TensorNVMe - - name: Install tensornvme - run: | - cd TensorNVMe apt update && apt install -y cmake - pip install -r requirements.txt - DISABLE_URING=1 pip install -v . + pip install -U pip setuptools==68.2.2 wheel --user + - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Download cub for CUDA 10.2 - run: | - CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') - - # check if it is CUDA 10.2 - # download cub - if [ "$CUDA_VERSION" = "10.2" ]; then - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ - fi - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v . 
- pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt + + - name: Install tensornvme + run: | + DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git + - name: Unit Testing run: | PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 8d98e775c..c6455604f 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -43,47 +43,28 @@ jobs: steps: - name: Install dependencies run: | + apt update && apt install -y cmake pip install -U pip setuptools==68.2.2 wheel --user - uses: actions/checkout@v2 with: - repository: hpcaitech/TensorNVMe ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - path: TensorNVMe - - - name: Install tensornvme - run: | - cd TensorNVMe - apt update && apt install -y cmake - pip install -r requirements.txt - DISABLE_URING=1 pip install -v . - - uses: actions/checkout@v2 - with: - ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - - name: Download cub for CUDA 10.2 - run: | - CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') - - # check if it is CUDA 10.2 - # download cub - if [ "$CUDA_VERSION" = "10.2" ]; then - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ - fi - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v . 
- pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt + + - name: Install tensornvme + run: | + DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git - name: Unit Testing run: | PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index 4ea86b609..d0b5c2164 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -52,6 +52,7 @@ jobs: mkdir sft_data mkdir prompt_data mkdir preference_data + mkdir kto_data ./tests/test_data_preparation.sh ./tests/test_train.sh env: @@ -61,3 +62,4 @@ jobs: SFT_DATASET: ./sft_data PROMPT_DATASET: ./prompt_data PREFERENCE_DATASET: ./preference_data + KTO_DATASET: ./kto_data diff --git a/applications/Colossal-LLaMA/prepare_sft_dataset.py b/applications/Colossal-LLaMA/prepare_sft_dataset.py index a857d6c0c..fe5790760 100644 --- a/applications/Colossal-LLaMA/prepare_sft_dataset.py +++ b/applications/Colossal-LLaMA/prepare_sft_dataset.py @@ -10,7 +10,7 @@ import math import os from multiprocessing import cpu_count -from colossal_llama.dataset.conversation import LLaMA2_Conv +from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft from datasets import dataset_dict, load_dataset from transformers import AddedToken, AutoTokenizer @@ -75,6 +75,8 @@ def main(): # Prepare to the tokenizer. tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + default_conversation = LLaMA3_Conv + # Fix split issue: https://github.com/huggingface/transformers/issues/23833 if args.llama_version == 2: tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index 43a360a9a..e74aad33c 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -128,6 +128,12 @@ def main() -> None: parser.add_argument("--zero", type=int, default=1) parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") + parser.add_argument( + "--skip_save_each_epoch", + action="store_true", + default=False, + help="skip saving the model checkpoint after each epoch is completed.", + ) args = parser.parse_args() with open(args.config_file, "w") as f: @@ -370,11 +376,17 @@ def main() -> None: ) total_loss.fill_(0.0) pbar.update() + # Save modeling. 
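For clarity, the checkpoint-saving rule that the new `--skip_save_each_epoch` flag controls can be summarized as a small predicate. This is a minimal sketch using stand-in names (`args`, `step`, `num_steps_per_epoch`), not the training script itself; the refactored condition in `train.py` follows below.

```python
# Minimal sketch of the saving rule controlled by --skip_save_each_epoch.
# `args`, `step`, and `num_steps_per_epoch` are stand-ins for the objects in train.py.
def should_save_checkpoint(args, step: int, num_steps_per_epoch: int) -> bool:
    # Periodic saving every `save_interval` optimizer steps (counted in micro-steps).
    save_now = args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
    # Unless the user opts out, also save once at the end of every epoch.
    if not args.skip_save_each_epoch:
        save_now = save_now or (step + 1) == num_steps_per_epoch
    return save_now
```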
- if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( - step + 1 - ) == len(dataloader): + save_model_condition = ( + args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 + ) + + if not args.skip_save_each_epoch: + save_model_condition = save_model_condition or (step + 1) == len(dataloader) + + if save_model_condition: coordinator.print_on_master("\nStart saving model checkpoint with running states") if args.use_neft: diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore index 33950adc0..757cbb5da 100755 --- a/applications/ColossalChat/.gitignore +++ b/applications/ColossalChat/.gitignore @@ -146,6 +146,9 @@ docs/.build examples/wandb/ examples/logs/ examples/output/ +examples/training_scripts/logs +examples/training_scripts/wandb +examples/training_scripts/output examples/awesome-chatgpt-prompts/ temp/ diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md index 769f0b3d0..de27ebaf6 100755 --- a/applications/ColossalChat/README.md +++ b/applications/ColossalChat/README.md @@ -23,6 +23,10 @@ - [Open QA](#open-qa) - [Limitation for LLaMA-finetuned models](#limitation) - [Limitation of dataset](#limitation) +- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization) +- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo) +- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo) +- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto) - [FAQ](#faq) - [How to save/load checkpoint](#faq) - [How to train with limited resources](#faq) @@ -135,17 +139,15 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the {"messages": [ { - "from": "human", + "from": "user", "content": "what are some pranks with a pen i can do?" }, { "from": "assistant", "content": "Are you looking for practical joke ideas?" }, - ... ] }, - ... ] ``` @@ -171,23 +173,20 @@ Below shows the preference dataset format used in training the reward model. "from": "human", "content": "Introduce butterflies species in Oregon." } - ] + ], "chosen": [ { "from": "assistant", "content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..." }, - ... ], "rejected": [ { "from": "assistant", "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..." }, - ... ] }, - ... ] ``` @@ -216,7 +215,6 @@ PPO uses two kind of training data--- the prompt data and the sft data (optional "from": "human", "content": "what are some pranks with a pen i can do?" } - ... ] }, ] @@ -262,9 +260,8 @@ experience buffer size = train_batch_size * accumulation_steps * num_tp_group ``` -## Alternative Option For RLHF: Direct Preference Optimization - -For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper (available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO. 
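The dataset-format examples above now use the roles `user` and `assistant` instead of `human`/`assistant`, and the tokenization utilities later in this diff enforce strict alternation. A hypothetical standalone check that mirrors that rule:

```python
# Hypothetical sanity check mirroring the stricter role validation added to the
# tokenization utilities: conversations must start with "user" and then alternate
# between "user" and "assistant".
ROLES = ["user", "assistant"]


def validate_messages(messages):
    for idx, mess in enumerate(messages):
        if mess["from"] != ROLES[idx % 2]:
            raise ValueError(
                "Messages should alternate between user and assistant and start "
                f"with a line from the user. Got:\n{messages}"
            )


validate_messages(
    [
        {"from": "user", "content": "what are some pranks with a pen i can do?"},
        {"from": "assistant", "content": "Are you looking for practical joke ideas?"},
    ]
)
```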
+## Alternative Option For RLHF: Direct Preference Optimization (DPO) +For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO. Read this [README](./examples/README.md) for more information. ### DPO Training Stage1 - Supervised Instructs Tuning @@ -277,6 +274,15 @@ For DPO training, you only need the preference dataset. Please follow the instru #### Step 2: Training You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More detais can be found in [example guideline](./examples/README.md). +## Alternative Option For RLHF: Simple Preference Optimization (SimPO) +Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but it abandons the use of the reference model, which makes the training more efficient. It also adds a reward shaping term called target reward margin to enhance training stability. It also use length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information. + +## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO) +Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference model free alignment method that use a mixture of SFT loss and a reinforcement leanring loss calculated based on odds-ratio-based implicit reward to makes the training more efficient and stable. Read this [README](./examples/README.md) for more information. + +## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO) +We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information. + ### Inference Quantization and Serving - After Training We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models. @@ -441,20 +447,6 @@ If you only have a single 24G GPU. Generally, using lora and "zero2-cpu" will be If you have multiple GPUs each has very limited VRAM, say 8GB. You can try the `3d` for the plugin option, which supports tensor parellelism, set `--tp` to the number of GPUs that you have. -## The Plan - -- [x] implement PPO fine-tuning -- [x] implement training reward model -- [x] support LoRA -- [x] support inference -- [x] support llama from [facebook](https://github.com/facebookresearch/llama) -- [x] implement PPO-ptx fine-tuning -- [x] support flash-attention -- [x] implement DPO fine-tuning -- [ ] integrate with Ray -- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL), -- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain) - ### Real-time progress You will find our progress in github [project broad](https://github.com/orgs/hpcaitech/projects/17/views/1). @@ -522,7 +514,7 @@ Coati is developed by ColossalAI Team: - [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT. - [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development. 
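The DPO and SimPO sections above describe the objectives only informally. The sketch below is based on the cited papers (arXiv:2305.18290 and arXiv:2405.14734), not on ColossalChat's exact `DpoLoss` implementation; inputs are assumed to be per-sequence summed log-probabilities, and the default `beta`/`gamma` values mirror the example training scripts.

```python
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # DPO: compare the policy's implicit reward for chosen vs. rejected responses
    # against a frozen reference model.
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    return -F.logsigmoid(beta * logits).mean()


def simpo_loss(policy_chosen_logps, policy_rejected_logps, chosen_len, rejected_len, beta=0.1, gamma=0.6):
    # SimPO: drop the reference model, length-normalize the log-probabilities,
    # and subtract a target reward margin `gamma` for training stability.
    logits = beta * (policy_chosen_logps / chosen_len - policy_rejected_logps / rejected_len) - gamma
    return -F.logsigmoid(logits).mean()
```

ORPO and KTO, described above as well, instead combine an odds-ratio penalty with the SFT loss and a prospect-theoretic value function, respectively; see the cited papers for the exact forms.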
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements. -- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO. +- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with updated acceleration framework. Add support for DPO, SimPO, ORPO. The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project. - [Zangwei Zheng](https://github.com/zhengzangw) @@ -572,6 +564,36 @@ We also appreciate the valuable suggestions provided by [Jian Hu](https://github journal = {GitHub repository}, howpublished = {\url{https://github.com/XueFuzhao/InstructionWild}}, } + +@misc{meng2024simposimplepreferenceoptimization, + title={SimPO: Simple Preference Optimization with a Reference-Free Reward}, + author={Yu Meng and Mengzhou Xia and Danqi Chen}, + year={2024}, + eprint={2405.14734}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2405.14734}, +} + +@misc{rafailov2023directpreferenceoptimizationlanguage, + title={Direct Preference Optimization: Your Language Model is Secretly a Reward Model}, + author={Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn}, + year={2023}, + eprint={2305.18290}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2305.18290}, +} + +@misc{hong2024orpomonolithicpreferenceoptimization, + title={ORPO: Monolithic Preference Optimization without Reference Model}, + author={Jiwoo Hong and Noah Lee and James Thorne}, + year={2024}, + eprint={2403.07691}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2403.07691}, +} ``` ## Licenses diff --git a/applications/ColossalChat/benchmarks/benchmark_dpo.sh b/applications/ColossalChat/benchmarks/benchmark_dpo.sh new file mode 100755 index 000000000..44d821a87 --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_dpo.sh @@ -0,0 +1,51 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +PROJECT_NAME="dpo" +PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data +DATASET_SIZE=320 + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +declare -a dataset=( + $BENCHMARK_DATA_DIR/arrow/part-0 +) + +# Generate dummy test data +python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference + + +colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2_cpu" \ + --max_epochs 1 \ + --accumulation_steps 1 \ + --batch_size 4 \ + --lr 1e-6 \ + --beta 0.1 \ + 
--mixed_precision "bf16" \ + --grad_clip 1.0 \ + --max_length 2048 \ + --weight_decay 0.01 \ + --warmup_steps 60 \ + --grad_checkpoint \ + --use_flash_attn diff --git a/applications/ColossalChat/benchmarks/benchmark_kto.sh b/applications/ColossalChat/benchmarks/benchmark_kto.sh new file mode 100755 index 000000000..82d3e3421 --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_kto.sh @@ -0,0 +1,51 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +PROJECT_NAME="kto" +PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data +DATASET_SIZE=80 + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +declare -a dataset=( + $BENCHMARK_DATA_DIR/arrow/part-0 +) + +# Generate dummy test data +python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto + + +colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2_cpu" \ + --max_epochs 1 \ + --accumulation_steps 1 \ + --batch_size 2 \ + --lr 1e-5 \ + --beta 0.1 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --max_length 2048 \ + --weight_decay 0.01 \ + --warmup_steps 60 \ + --grad_checkpoint \ + --use_flash_attn diff --git a/applications/ColossalChat/benchmarks/benchmark_orpo.sh b/applications/ColossalChat/benchmarks/benchmark_orpo.sh new file mode 100755 index 000000000..f8fb264ae --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_orpo.sh @@ -0,0 +1,51 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +PROJECT_NAME="orpo" +PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +BENCHMARK_DATA_DIR="./temp/orpo" # Path to benchmark data +DATASET_SIZE=160 + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +declare -a dataset=( + $BENCHMARK_DATA_DIR/arrow/part-0 +) + +# Generate dummy test data +python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference + + +colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_orpo.py \ + --pretrain 
$PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2" \ + --max_epochs 1 \ + --accumulation_steps 1 \ + --batch_size 4 \ + --lr 8e-6 \ + --lam 0.5 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --max_length 2048 \ + --weight_decay 0.01 \ + --warmup_steps 60 \ + --grad_checkpoint \ + --use_flash_attn diff --git a/applications/ColossalChat/benchmarks/benchmark_sft.sh b/applications/ColossalChat/benchmarks/benchmark_sft.sh new file mode 100755 index 000000000..efcd428dd --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_sft.sh @@ -0,0 +1,50 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +PROJECT_NAME="sft" +PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +BENCHMARK_DATA_DIR="./temp/sft" # Path to benchmark data +DATASET_SIZE=640 + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +declare -a dataset=( + $BENCHMARK_DATA_DIR/arrow/part-0 +) + + +# Generate dummy test data +python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type sft + + +# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size +colossalai run --nproc_per_node 1 --master_port 31312 ../examples/training_scripts/train_sft.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin zero2 \ + --batch_size 8 \ + --max_epochs 1 \ + --accumulation_steps 1 \ + --lr 5e-5 \ + --lora_rank 32 \ + --max_len 2048 \ + --grad_checkpoint \ + --use_flash_attn diff --git a/applications/ColossalChat/benchmarks/benchmark_simpo.sh b/applications/ColossalChat/benchmarks/benchmark_simpo.sh new file mode 100755 index 000000000..47dfc8595 --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_simpo.sh @@ -0,0 +1,55 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +PROJECT_NAME="simpo" +PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +BENCHMARK_DATA_DIR="./temp/simpo" # Path to benchmark data +DATASET_SIZE=640 + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +declare -a dataset=( + $BENCHMARK_DATA_DIR/arrow/part-0 +) + +# Generate dummy 
test data +python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference + + +colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2_cpu" \ + --loss_type "simpo_loss" \ + --max_epochs 1 \ + --accumulation_steps 1 \ + --batch_size 8 \ + --lr 1e-6 \ + --beta 0.1 \ + --gamma 0.6 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --max_length 2048 \ + --weight_decay 0.01 \ + --warmup_steps 60 \ + --disable_reference_model \ + --length_normalization \ + --grad_checkpoint \ + --use_flash_attn diff --git a/applications/ColossalChat/benchmarks/dummy_dataset.py b/applications/ColossalChat/benchmarks/dummy_dataset.py new file mode 100644 index 000000000..9af0f1641 --- /dev/null +++ b/applications/ColossalChat/benchmarks/dummy_dataset.py @@ -0,0 +1,30 @@ +from typing import Callable + +from torch.utils.data import Dataset + + +class DummyLLMDataset(Dataset): + def __init__(self, keys, seq_len, size=500, gen_fn={}): + self.keys = keys + self.gen_fn = gen_fn + self.seq_len = seq_len + self.data = self._generate_data() + self.size = size + + def _generate_data(self): + data = {} + for key in self.keys: + if key in self.gen_fn: + data[key] = self.gen_fn[key] + else: + data[key] = [1] * self.seq_len + return data + + def __len__(self): + return self.size + + def __getitem__(self, idx): + return { + key: self.data[key] if not isinstance(self.data[key], Callable) else self.data[key](idx) + for key in self.keys + } diff --git a/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py new file mode 100644 index 000000000..f501c5358 --- /dev/null +++ b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py @@ -0,0 +1,105 @@ +import argparse +import json +import os +import time +from multiprocessing import cpu_count + +from datasets import load_dataset +from dummy_dataset import DummyLLMDataset + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + required=True, + default=None, + help="The output dir", + ) + parser.add_argument( + "--dataset_size", + type=int, + required=True, + default=None, + help="The size of data", + ) + parser.add_argument( + "--max_length", + type=int, + required=True, + default=None, + help="The max length of data", + ) + parser.add_argument( + "--data_type", + type=str, + required=True, + default=None, + help="The type of data, choose one from ['sft', 'prompt', 'preference', 'kto']", + ) + args = parser.parse_args() + if args.data_type == "sft": + dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_length, args.dataset_size) + elif args.data_type == "prompt": + # pass PPO dataset is prepared separately + pass + elif args.data_type == "preference": + dataset = DummyLLMDataset( + ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"], + args.max_length, + args.dataset_size, + ) + elif args.data_type == "kto": + dataset = DummyLLMDataset( + ["prompt", "completion", "label"], + args.max_length - 512, + args.dataset_size, + gen_fn={ + "completion": lambda x: [1] * 512, + "label": lambda x: x % 2, + }, + ) + else: + raise ValueError(f"Unknown data type 
{args.data_type}") + + # Save each jsonl spliced dataset. + output_index = "0" + output_name = f"part-{output_index}" + os.makedirs(args.data_dir, exist_ok=True) + output_jsonl_path = os.path.join(args.data_dir, "json") + output_arrow_path = os.path.join(args.data_dir, "arrow") + output_cache_path = os.path.join(args.data_dir, "cache") + os.makedirs(output_jsonl_path, exist_ok=True) + os.makedirs(output_arrow_path, exist_ok=True) + output_jsonl_file_path = os.path.join(output_jsonl_path, output_name + ".jsonl") + st = time.time() + with open(file=output_jsonl_file_path, mode="w", encoding="utf-8") as fp_writer: + count = 0 + for i in range(len(dataset)): + data_point = dataset[i] + if count % 500 == 0: + logger.info(f"processing {count} spliced data points for {fp_writer.name}") + count += 1 + fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n") + logger.info( + f"Current file {fp_writer.name}; " + f"Data size: {len(dataset)}; " + f"Time cost: {round((time.time() - st) / 60, 6)} minutes." + ) + # Save each arrow spliced dataset + output_arrow_file_path = os.path.join(output_arrow_path, output_name) + logger.info(f"Start to save {output_arrow_file_path}") + dataset = load_dataset( + path="json", + data_files=[output_jsonl_file_path], + cache_dir=os.path.join(output_cache_path, "tokenized"), + keep_in_memory=False, + num_proc=cpu_count(), + split="train", + ) + dataset.save_to_disk(dataset_path=output_arrow_file_path, num_proc=min(len(dataset), cpu_count())) diff --git a/applications/ColossalChat/coati/dataset/__init__.py b/applications/ColossalChat/coati/dataset/__init__.py index deb7b6d92..8e9060a1a 100755 --- a/applications/ColossalChat/coati/dataset/__init__.py +++ b/applications/ColossalChat/coati/dataset/__init__.py @@ -1,24 +1,26 @@ from .conversation import Conversation, setup_conversation_template from .loader import ( + DataCollatorForKTODataset, DataCollatorForPreferenceDataset, DataCollatorForPromptDataset, DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset, ) -from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf +from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft __all__ = [ - "tokenize_prompt_dataset", + "tokenize_prompt", "DataCollatorForPromptDataset", "is_rank_0", "DataCollatorForPreferenceDataset", "DataCollatorForSupervisedDataset", + "DataCollatorForKTODataset", "StatefulDistributedSampler", "load_tokenized_dataset", - "supervised_tokenize_pretrain", - "supervised_tokenize_sft", + "tokenize_sft", "tokenize_rlhf", + "tokenize_kto", "setup_conversation_template", "Conversation", ] diff --git a/applications/ColossalChat/coati/dataset/conversation.py b/applications/ColossalChat/coati/dataset/conversation.py index 37900f3b8..a77c220d3 100755 --- a/applications/ColossalChat/coati/dataset/conversation.py +++ b/applications/ColossalChat/coati/dataset/conversation.py @@ -18,6 +18,7 @@ class Conversation: chat_template: str stop_ids: List[int] end_of_assistant: str + roles = ["user", "assistant"] @classmethod def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict): @@ -85,7 +86,7 @@ class Conversation: Raises: AssertionError: If the role is not 'user' or 'assistant'. 
""" - assert role in ["user", "assistant"] + assert role in self.roles self.messages.append({"role": role, "content": message}) def copy(self): diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index a0cd17bb4..b92cd76ad 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -28,6 +28,8 @@ def load_tokenized_dataset( Each instance of dataset is a dictionary with `{'input_ids': List[int], 'labels': List[int], sequence: str}` format. """ + if not dataset_paths: + return None mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"}) assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}" @@ -233,6 +235,91 @@ class DataCollatorForPreferenceDataset(object): ) +@dataclass +class DataCollatorForKTODataset(object): + """ + Collate instances for kto dataset. + Each input instance is a tokenized dictionary with fields + `prompt`(List[int]), `completion`(List[int]) and `label`(bool). + Each output instance is a tokenized dictionary with fields + `kl_input_ids`(List[int]), `kl_attention_mask`(List[int]) and `kl_loss_mask`(List[int]). + `input_ids`(List[int]), `attention_mask`(List[int]), `loss_mask`(List[int]) and `label`(bool). + """ + + tokenizer: PreTrainedTokenizer + max_length: int = 4096 + ignore_index: int = -100 + + def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: + """ + + Args: + instances (`Sequence[Dict[str, List[int]]]`): + Mini-batch samples, each sample is stored in an individual dictionary contains the following fields: + `prompt`(List[int]), `completion`(List[int]) and `label`(bool, if the sample is desirable or not). + + Returns: + (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: + `input_ids`: `torch.Tensor` of shape (bsz, max_len); + `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); + `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. 
+ """ + assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( + f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " + f"but now `{self.tokenizer.pad_token_id}`" + ) + # prepare the preference data + prompt = [torch.LongTensor(instance["prompt"]) for instance in instances] + prompt_zeros = [torch.zeros_like(t) for t in prompt] + completion = [torch.LongTensor(instance["completion"]) for instance in instances] + completion_ones = [torch.ones_like(t) for t in completion] + label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances] + input_ids = [torch.cat([prompt[i], completion[i]], dim=-1) for i in range(len(instances))] + loss_mask = [torch.cat([prompt_zeros[i], completion_ones[i]], dim=-1) for i in range(len(instances))] + # right padding + input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + loss_mask = torch.nn.utils.rnn.pad_sequence( + sequences=loss_mask, batch_first=True, padding_value=0 + ) # (bsz, max_len) + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + loss_mask = F.pad(loss_mask, (0, to_pad), value=0) + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) + + # prepare kt data + kl_completion = completion[::-1] # y' + kl_completion_ones = [torch.ones_like(t) for t in kl_completion] + kl_input_ids = [torch.cat([prompt[i], kl_completion[i]], dim=-1) for i in range(len(instances))] + kl_loss_mask = [torch.cat([prompt_zeros[i], kl_completion_ones[i]], dim=-1) for i in range(len(instances))] + # right padding + kl_input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=kl_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + kl_loss_mask = torch.nn.utils.rnn.pad_sequence( + sequences=kl_loss_mask, batch_first=True, padding_value=0 + ) # (bsz, max_len) + to_pad = self.max_length - kl_input_ids.size(1) + kl_input_ids = F.pad(kl_input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + kl_loss_mask = F.pad(kl_loss_mask, (0, to_pad), value=0) + kl_attention_mask = kl_input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) + data_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "label": torch.stack(label), + "kl_input_ids": kl_input_ids, + "kl_attention_mask": kl_attention_mask, + "kl_loss_mask": kl_loss_mask, + } + return data_dict + + class StatefulDistributedSampler(DistributedSampler): def __init__( self, diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py index 34828cbaf..9eb2eba87 100755 --- a/applications/ColossalChat/coati/dataset/tokenization_utils.py +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -23,11 +23,10 @@ IGNORE_INDEX = -100 DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] -def supervised_tokenize_sft( +def tokenize_sft( data_point: Dict[str, str], tokenizer: PreTrainedTokenizer, conversation_template: Conversation = None, - ignore_index: int = None, max_length: int = 4096, ) -> Dict[str, Union[int, str, List[int]]]: """ @@ -39,51 +38,37 @@ def supervised_tokenize_sft( Args: data_point: the data point of the following format - {"messages": [{"from": "human", "content": "xxx"}, {"from": 
"assistant", "content": "xxx"}]} + {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} tokenizer: the tokenizer whose conversation_template: the conversation template to apply ignore_index: the ignore index when calculate loss during training max_length: the maximum context length """ - if ignore_index is None: - ignore_index = IGNORE_INDEX + ignore_index = IGNORE_INDEX messages = data_point["messages"] template = deepcopy(conversation_template) template.messages = [] - - for mess in messages: - from_str = mess["from"] - if from_str.lower() == "human": - from_str = "user" - elif from_str.lower() == "assistant": - from_str = "assistant" - else: - raise ValueError(f"Unsupported role {from_str.lower()}") - - template.append_message(from_str, mess["content"]) + for idx, mess in enumerate(messages): + if mess["from"] != template.roles[idx % 2]: + raise ValueError( + f"Message should iterate between user and assistant and starts with a \ + line from the user. Got the following data:\n{messages}" + ) + template.append_message(mess["from"], mess["content"]) if len(template.messages) % 2 != 0: + # Force to end with assistant response template.messages = template.messages[0:-1] - # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time. - turns = [i for i in range(1, len(messages) // 2 + 1)] - - lo, hi = 0, len(turns) - while lo < hi: - mid = (lo + hi) // 2 - if max_length - 1 < len( - tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0] - ): - hi = mid - else: - lo = mid + 1 - target_turn_index = lo - - # The tokenized length for first turn already exceeds `max_length - 1`. - if target_turn_index - 1 < 0: - warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.") + # tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines + prompt = template.get_prompt() + chunks, require_loss = split_templated_prompt_into_chunks( + template.messages, prompt, conversation_template.end_of_assistant + ) + tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length) + if tokenized is None: return dict( input_ids=None, labels=None, @@ -93,44 +78,18 @@ def supervised_tokenize_sft( seq_category=None, ) - target_turn = turns[target_turn_index - 1] - prompt = template.get_prompt(2 * target_turn) - chunks, require_loss = split_templated_prompt_into_chunks( - template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant - ) - tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) - labels = [ignore_index] * len(tokenized) for start, end in zip(starts, ends): - if end == len(tokenized): - tokenized = tokenized + [tokenizer.eos_token_id] - labels = labels + [ignore_index] labels[start:end] = tokenized[start:end] - # truncate the sequence at the last token that requires loss calculation - to_truncate_len = 0 - for i in range(len(tokenized) - 1, -1, -1): - if labels[i] == ignore_index: - to_truncate_len += 1 - else: - break - tokenized = tokenized[: len(tokenized) - to_truncate_len] - labels = labels[: len(labels) - to_truncate_len] - if tokenizer.bos_token_id is not None: + # Force to add bos token at the beginning of the tokenized sequence if the input ids doesn;t starts with bos if tokenized[0] != tokenizer.bos_token_id: + # Some chat templates already include bos token tokenized = [tokenizer.bos_token_id] + tokenized - labels = [ignore_index] + 
labels + labels = [-100] + labels - if tokenizer.eos_token_id is not None: - # Force to add eos token at the end of the tokenized sequence - if tokenized[-1] != tokenizer.eos_token_id: - tokenized = tokenized + [tokenizer.eos_token_id] - labels = labels + [tokenizer.eos_token_id] - else: - labels[-1] = tokenizer.eos_token_id - - # For some model without bos/eos may raise the following errors + # log decoded inputs and labels for debugging inputs_decode = tokenizer.decode(tokenized) start = 0 end = 0 @@ -167,11 +126,10 @@ def supervised_tokenize_sft( ) -def tokenize_prompt_dataset( +def tokenize_prompt( data_point: Dict[str, str], tokenizer: PreTrainedTokenizer, conversation_template: Conversation = None, - ignore_index: int = None, max_length: int = 4096, ) -> Dict[str, Union[int, str, List[int]]]: """ @@ -179,48 +137,39 @@ def tokenize_prompt_dataset( "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]" Args: data_point: the data point of the following format - {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} + {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} tokenizer: the tokenizer whose conversation_template: the conversation template to apply ignore_index: the ignore index when calculate loss during training max_length: the maximum context length """ - if ignore_index is None: - ignore_index = IGNORE_INDEX messages = data_point["messages"] template = deepcopy(conversation_template) template.messages = [] - for mess in messages: - from_str = mess["from"] - if from_str.lower() == "human": - from_str = "user" - elif from_str.lower() == "assistant": - from_str = "assistant" - else: - raise ValueError(f"Unsupported role {from_str.lower()}") - - template.append_message(from_str, mess["content"]) + for idx, mess in enumerate(messages): + if mess["from"] != template.roles[idx % 2]: + raise ValueError( + f"Message should iterate between user and assistant and starts with a \ + line from the user. Got the following data:\n{messages}" + ) + template.append_message(mess["from"], mess["content"]) # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time. - target_turn = len(template.messages) - if target_turn % 2 != 1: + if len(template.messages) % 2 != 1: # exclude the answer if provided. keep only the prompt - target_turn = target_turn - 1 + template.messages = template.messages[:-1] # Prepare data - prompt = template.get_prompt(target_turn, add_generation_prompt=True) - chunks, require_loss = split_templated_prompt_into_chunks( - template.messages[:target_turn], prompt, conversation_template.end_of_assistant - ) - tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) + prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True) + tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0] + if tokenizer.bos_token_id is not None: if tokenized[0] != tokenizer.bos_token_id: tokenized = [tokenizer.bos_token_id] + tokenized - # Skip overlength data - if max_length - 1 < len(tokenized): + if len(tokenized) > max_length: return dict( input_ids=None, inputs_decode=None, @@ -231,47 +180,32 @@ def tokenize_prompt_dataset( # `inputs_decode` can be used to check whether the tokenization method is true. 
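As a reference for the refactored `tokenize_sft` above, the expected input and output roughly look as follows. The token ids themselves depend on the tokenizer and chat template, so only the structure is shown; field names follow the return dictionary in this diff.

```python
# Hypothetical illustration of tokenize_sft's input and output structure.
# Real token ids depend on the tokenizer and the conversation template.
data_point = {
    "messages": [
        {"from": "user", "content": "what are some pranks with a pen i can do?"},
        {"from": "assistant", "content": "Are you looking for practical joke ideas?"},
    ]
}
# result = tokenize_sft(data_point, tokenizer, conversation_template, max_length=4096)
# result["input_ids"]: the templated conversation, truncated to max_length, with a
#                      bos token prepended if the chat template did not add one.
# result["labels"]:    same length as input_ids; IGNORE_INDEX (-100) everywhere except
#                      the spans that correspond to assistant responses.
# result["inputs_decode"]: the decoded prompt, kept for debugging the tokenization.
```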
return dict( input_ids=tokenized, - inputs_decode=tokenizer.decode(tokenized), + inputs_decode=prompt, seq_length=len(tokenized), seq_category=data_point["category"] if "category" in data_point else "None", ) -def apply_rlhf_data_format( - template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False -): +def apply_rlhf_data_format(template: Conversation, tokenizer: Any): target_turn = int(len(template.messages) / 2) prompt = template.get_prompt(target_turn * 2) chunks, require_loss = split_templated_prompt_into_chunks( template.messages[: 2 * target_turn], prompt, template.end_of_assistant ) - tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) - loss_mask = [0] * len(tokenized) - mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id - if mask_token is None: - mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen + # no truncation applied + tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None) + loss_mask = [0] * len(tokenized) label_decode = [] - for start, end in zip(starts[-1:], ends[-1:]): - # only the last round (chosen/rejected) counts - if end == len(tokenized): - tokenized = tokenized + [tokenizer.eos_token_id] - loss_mask = loss_mask + [1] - loss_mask[start:end] = [1] * len(loss_mask[start:end]) - label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False)) + # only the last round (chosen/rejected) is used to calculate loss + for i in range(starts[-1], ends[-1]): + loss_mask[i] = 1 + label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False)) if tokenizer.bos_token_id is not None: if tokenized[0] != tokenizer.bos_token_id: tokenized = [tokenizer.bos_token_id] + tokenized loss_mask = [0] + loss_mask - if tokenizer.eos_token_id is not None: - # Force to add eos token at the end of the tokenized sequence - if tokenized[-1] != tokenizer.eos_token_id: - tokenized = tokenized + [tokenizer.eos_token_id] - loss_mask = loss_mask + [1] - else: - loss_mask[-1] = 1 - return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode} @@ -279,39 +213,29 @@ def tokenize_rlhf( data_point: Dict[str, str], tokenizer: PreTrainedTokenizer, conversation_template: Conversation = None, - ignore_index: int = None, max_length: int = 4096, ) -> Dict[str, Union[int, str, List[int]]]: """ A tokenization function to tokenize an original pretraining data point as following: - {"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}], + {"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}], "chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}} """ - if ignore_index is None: - ignore_index = IGNORE_INDEX context = data_point["context"] template = deepcopy(conversation_template) template.clear() - for mess in context: - from_str = mess["from"] - if from_str.lower() == "human": - from_str = "user" - elif from_str.lower() == "assistant": - from_str = "assistant" - else: - raise ValueError(f"Unsupported role {from_str.lower()}") - - if len(template.messages) > 0 and from_str == template.messages[-1]["role"]: - # Concate adjacent message from the same role - template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"]) - else: - template.append_message(from_str, mess["content"]) + for idx, mess in enumerate(context): + if mess["from"] != 
template.roles[idx % 2]: + raise ValueError( + f"Message should iterate between user and assistant and starts with a \ + line from the user. Got the following data:\n{context}" + ) + template.append_message(mess["from"], mess["content"]) if len(template.messages) % 2 != 1: warnings.warn( - "Please make sure leading context starts and ends with a line from human\nLeading context: " + "Please make sure leading context starts and ends with a line from user\nLeading context: " + str(template.messages) ) return dict( @@ -322,31 +246,27 @@ def tokenize_rlhf( rejected_loss_mask=None, rejected_label_decode=None, ) - round_of_context = int((len(template.messages) - 1) / 2) - assert context[-1]["from"].lower() == "human", "The last message in context should be from human." + assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user." chosen = deepcopy(template) rejected = deepcopy(template) + chosen_continuation = data_point["chosen"] + rejected_continuation = data_point["rejected"] + for round in range(len(chosen_continuation)): + if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]: + raise ValueError( + f"Message should iterate between user and assistant and starts with a \ + line from the user. Got the following data:\n{chosen_continuation}" + ) + chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"]) - for round in range(len(data_point["chosen"])): - from_str = data_point["chosen"][round]["from"] - if from_str.lower() == "human": - from_str = "user" - elif from_str.lower() == "assistant": - from_str = "assistant" - else: - raise ValueError(f"Unsupported role {from_str.lower()}") - chosen.append_message(from_str, data_point["chosen"][round]["content"]) - - for round in range(len(data_point["rejected"])): - from_str = data_point["rejected"][round]["from"] - if from_str.lower() == "human": - from_str = "user" - elif from_str.lower() == "assistant": - from_str = "assistant" - else: - raise ValueError(f"Unsupported role {from_str.lower()}") - rejected.append_message(from_str, data_point["rejected"][round]["content"]) + for round in range(len(rejected_continuation)): + if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]: + raise ValueError( + f"Message should iterate between user and assistant and starts with a \ + line from the user. 
Got the following data:\n{rejected_continuation}" + ) + rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"]) ( chosen_input_ids, @@ -356,48 +276,22 @@ def tokenize_rlhf( rejected_loss_mask, rejected_label_decode, ) = (None, None, None, None, None, None) - if ( - len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0]) - <= max_length - 1 - and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0]) - <= max_length - 1 - ): - chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context) - (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = ( - chosen_data_packed["input_ids"], - chosen_data_packed["loss_mask"], - chosen_data_packed["label_decode"], - ) - rejected_data_packed = apply_rlhf_data_format( - rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True - ) - (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = ( - rejected_data_packed["input_ids"], - rejected_data_packed["loss_mask"], - rejected_data_packed["label_decode"], - ) + chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer) + (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = ( + chosen_data_packed["input_ids"], + chosen_data_packed["loss_mask"], + chosen_data_packed["label_decode"], + ) - # Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long - if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask): - return dict( - chosen_input_ids=None, - chosen_loss_mask=None, - chosen_label_decode=None, - rejected_input_ids=None, - rejected_loss_mask=None, - rejected_label_decode=None, - ) + rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer) + (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = ( + rejected_data_packed["input_ids"], + rejected_data_packed["loss_mask"], + rejected_data_packed["label_decode"], + ) - return { - "chosen_input_ids": chosen_input_ids, - "chosen_loss_mask": chosen_loss_mask, - "chosen_label_decode": chosen_label_decode, - "rejected_input_ids": rejected_input_ids, - "rejected_loss_mask": rejected_loss_mask, - "rejected_label_decode": rejected_label_decode, - } - else: + if len(chosen_input_ids) > max_length or len(rejected_input_ids) > max_length: return dict( chosen_input_ids=None, chosen_loss_mask=None, @@ -406,3 +300,81 @@ def tokenize_rlhf( rejected_loss_mask=None, rejected_label_decode=None, ) + # Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long + if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0: + return dict( + chosen_input_ids=None, + chosen_loss_mask=None, + chosen_label_decode=None, + rejected_input_ids=None, + rejected_loss_mask=None, + rejected_label_decode=None, + ) + + return { + "chosen_input_ids": chosen_input_ids, + "chosen_loss_mask": chosen_loss_mask, + "chosen_label_decode": chosen_label_decode, + "rejected_input_ids": rejected_input_ids, + "rejected_loss_mask": rejected_loss_mask, + "rejected_label_decode": rejected_label_decode, + } + + +def tokenize_kto( + data_point: Dict[str, str], + tokenizer: PreTrainedTokenizer, + conversation_template: Conversation = None, + max_length: int = 4096, +) -> Dict[str, Union[int, str, List[int]]]: + """ + Tokenize a dataset for KTO training + The raw input data is conversation that have the following format + { + "prompt": [{"from": "user", 
"content": "xxx"}...], + "completion": {"from": "assistant", "content": "xxx"}, + "label": true/false + } + It returns three fields + The context, which contain the query and the assistant start, + the completion, which only contains the assistance's answer, + and a binary label, which indicates if the sample is prefered or not + """ + prompt = data_point["prompt"] + completion = data_point["completion"] + template = deepcopy(conversation_template) + template.clear() + + if prompt[0].get("from", None) != "user": + raise ValueError("conversation should start with user") + if completion.get("from", None) != "assistant": + raise ValueError("conversation should end with assistant") + + for mess in prompt: + if mess.get("from", None) == "user": + template.append_message("user", mess["content"]) + elif mess.get("from", None) == "assistant": + template.append_message("assistant", mess["content"]) + else: + raise ValueError(f"Unsupported role {mess.get('from', None)}") + generation_prompt = template.get_prompt(len(prompt), add_generation_prompt=True) + template.append_message("assistant", completion["content"]) + full_prompt = template.get_prompt(len(prompt) + 1, add_generation_prompt=False) + tokenized_full_prompt = tokenizer(full_prompt, add_special_tokens=False)["input_ids"] + if len(tokenized_full_prompt) + 1 > max_length: + return dict(prompt=None, completion=None, label=None, input_id_decode=None, completion_decode=None) + tokenized_generation_prompt = tokenizer(generation_prompt, add_special_tokens=False)["input_ids"] + tokenized_completion = tokenized_full_prompt[len(tokenized_generation_prompt) :] + tokenized_completion = deepcopy(tokenized_completion) + if tokenizer.bos_token_id is not None and tokenized_generation_prompt[0] != tokenizer.bos_token_id: + tokenized_generation_prompt = [tokenizer.bos_token_id] + tokenized_generation_prompt + decoded_full_prompt = tokenizer.decode(tokenized_full_prompt, skip_special_tokens=False) + decoded_completion = tokenizer.decode(tokenized_completion, skip_special_tokens=False) + + return { + "prompt": tokenized_generation_prompt, + "completion": tokenized_completion, + "label": data_point["label"], + "input_id_decode": decoded_full_prompt, + "completion_decode": decoded_completion, + } diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py index f41a4d772..42c3191db 100755 --- a/applications/ColossalChat/coati/dataset/utils.py +++ b/applications/ColossalChat/coati/dataset/utils.py @@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s return -1 -def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]): +def tokenize_and_concatenate( + tokenizer: PreTrainedTokenizer, + text: List[str], + require_loss: List[bool], + max_length: int, + discard_non_loss_tokens_at_tail: bool = True, +): """ Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs. @@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization. text (List[str]): The list of texts to tokenize. require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation. 
+ max_length: used to truncate the input ids + discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail + + if the first round has already exceeded max length + - if the user query already exceeded max length, discard the sample + - if only the first assistant response exceeded max length, truncate the response to fit the max length + else keep the first several complete rounds of the conversations until max length is reached Returns: Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids, @@ -106,10 +119,18 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re loss_ends = [] for s, r in zip(text, require_loss): tokenized = tokenizer(s, add_special_tokens=False)["input_ids"] - if r: - loss_starts.append(len(input_ids)) - loss_ends.append(len(input_ids) + len(tokenized)) - input_ids.extend(tokenized) + if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0: + if r: + loss_starts.append(len(input_ids)) + loss_ends.append(len(input_ids) + len(tokenized)) + input_ids.extend(tokenized) + if max_length and loss_starts[0] >= max_length: + return None, None, None + if discard_non_loss_tokens_at_tail: + input_ids = input_ids[: loss_ends[-1]] + if max_length: + input_ids = input_ids[:max_length] + loss_ends[-1] = min(max_length, loss_ends[-1]) return input_ids, loss_starts, loss_ends @@ -125,6 +146,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s content_length = ( prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur ) + # if the tokenized content starts with a leading space, we want to keep it in loss calculation + # e.g., Assistant: I am saying... + # if the tokenized content doesn't start with a leading space, we only need to keep the content in loss calculation + # e.g., + # Assistant: # '\n' as line breaker + # I am saying...
if prompt[first_occur - 1] != " ": chunks.append(prompt[start_idx:first_occur]) chunks.append(prompt[first_occur : first_occur + content_length]) diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py index 14073207f..fba0949e3 100755 --- a/applications/ColossalChat/coati/models/__init__.py +++ b/applications/ColossalChat/coati/models/__init__.py @@ -1,8 +1,8 @@ from .base import BaseModel from .critic import Critic from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn -from .lora import convert_to_lora_module -from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss +from .lora import LoraConfig, convert_to_lora_module, lora_manager +from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from .reward_model import RewardModel from .utils import disable_dropout @@ -14,9 +14,11 @@ __all__ = [ "ValueLoss", "LogSigLoss", "LogExpLoss", + "LoraConfig", + "lora_manager", "convert_to_lora_module", "DpoLoss", - "generate", + "KTOLoss", "generate", "generate_streaming", "disable_dropout", "update_model_kwargs_fn", diff --git a/applications/ColossalChat/coati/models/base.py b/applications/ColossalChat/coati/models/base.py index fcea9414b..cfdffdf28 100755 --- a/applications/ColossalChat/coati/models/base.py +++ b/applications/ColossalChat/coati/models/base.py @@ -42,7 +42,6 @@ class BaseModel(nn.Module): out = self.model(dummy_input) self.last_hidden_state_size = out.last_hidden_state.shape[-1] self.model = self.model.cpu() - # print("self.last_hidden_state_size: ",self.last_hidden_state_size) def resize_token_embeddings(self, *args, **kwargs): """ diff --git a/applications/ColossalChat/coati/models/lora.py b/applications/ColossalChat/coati/models/lora.py index 9553b00ff..aa5f6ecf8 100755 --- a/applications/ColossalChat/coati/models/lora.py +++ b/applications/ColossalChat/coati/models/lora.py @@ -5,10 +5,11 @@ LORA utils import dataclasses import math import warnings -from typing import Optional +from typing import List, Optional, Union import loralib as lora import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F @@ -18,148 +19,349 @@ logger = get_dist_logger() @dataclasses.dataclass -class LoRAManager: - merge_weights: bool = False +class LoraManager: + able_to_merge: bool = True -LORA_MANAGER = LoRAManager() +lora_manager = LoraManager() -class LoraLinear(lora.LoRALayer, nn.Module): - """Replace in-place ops to out-of-place ops to fit gemini.
Convert a torch.nn.Linear to LoraLinear.""" +@dataclasses.dataclass +class LoraConfig: + r: int = 0 + lora_alpha: int = 32 + linear_lora_dropout: float = 0.1 + embedding_lora_dropout: float = 0.0 + lora_train_bias: str = "none" + lora_initialization_method: str = "kaiming_uniform" + target_modules: List = None + @classmethod + def from_file(cls, config_file: str): + import json + + with open(config_file, "r") as f: + config = json.load(f) + return cls(**config) + + +class LoraBase(lora.LoRALayer, nn.Module): def __init__( self, - weight: nn.Parameter, - bias: Optional[nn.Parameter], r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - fan_in_fan_out: bool = False, + lora_alpha: int = 32, + lora_dropout: float = 0.1, + lora_initialization_method: str = "kaiming_uniform", ): nn.Module.__init__(self) lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) - self.weight = weight - self.bias = bias - - out_features, in_features = weight.shape - self.in_features = in_features - self.out_features = out_features - - self.fan_in_fan_out = fan_in_fan_out - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) - self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - if fan_in_fan_out: - self.weight.data = self.weight.data.T + self.r = r + self.lora_alpha = lora_alpha + self.lora_dropout = nn.Dropout(lora_dropout) + self.merged = False + self.lora_initialization_method = lora_initialization_method + self.weight = None + self.bias = None + self.lora_A = None + self.lora_B = None def reset_parameters(self): if hasattr(self, "lora_A"): - # Initialize A with the default values for nn.Linear and set B to zero. - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) + if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != ( + self.out_features, + self.in_features, + ): + # Initialize A with the default values for nn.Linear and set B to zero. + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + elif self.lora_initialization_method == "PiSSA": + # PiSSA method in this paper: https://arxiv.org/abs/2404.02948 + # Assume the SVD of the original weights is W = USV^T + # Initialize a frozen weight to U[:,r:]S[r:,r:]V^T[:,r:] to store less significent part of W + # Only A, B are trainable, which are initialized to S[r:,:r]^0.5V^T[:,:r] and U[:,:r]S[r:,:r] respectively + # self.scaling = 1. 
+ # SVD + U, S, Vh = torch.svd_lowrank( + self.weight.to(torch.float32).data, self.r, niter=4 + ) # U: [out_features, in_features], S: [in_features], V: [in_features, in_features] + # weight_backup = self.weight.clone() + + # Initialize A, B + S = S / self.scaling + self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous() + self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous() + # Initialize weight + # To reduce floating point error, we use residual instead of directly using U[:, :self.r] @ S[:self.r] @ Vh[:self.r, :] + self.weight.data = ( + ((self.weight - self.scaling * self.lora_B @ self.lora_A)).contiguous().to(self.weight.dtype) + ) + self.lora_A.requires_grad = True + self.lora_B.requires_grad = True + else: + raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}") def train(self, mode: bool = True): """ This function runs when model.train() is invoked. It is used to prepare the linear layer for training """ - def T(w): - return w.T if self.fan_in_fan_out else w - self.training = mode - if LORA_MANAGER.merge_weights: - if mode and self.merged: - warnings.warn("Invoke module.train() would unmerge LoRA weights.") - raise NotImplementedError("LoRA unmerge is not tested.") - # Make sure that the weights are not merged - if self.r > 0: - if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"): - # FIXME(csric): temporary fix - self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features))) - self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r))) - self.reset_parameters() - else: - self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling - self.merged = False - elif not mode and not self.merged: - warnings.warn("Invoke module.eval() would merge LoRA weights.") - # Merge the weights and mark it - if self.r > 0: - self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - delattr(self, "lora_A") - delattr(self, "lora_B") - self.merged = True + if mode and self.merged: + warnings.warn("Invoke module.train() would unmerge LoRA weights.") + raise NotImplementedError("LoRA unmerge is not tested.") + elif not mode and not self.merged and lora_manager.able_to_merge: + warnings.warn("Invoke module.eval() would merge LoRA weights.") + # Merge the weights and mark it + if self.r > 0: + self.weight.data += self.lora_B @ self.lora_A * self.scaling + delattr(self, "lora_A") + delattr(self, "lora_B") + self.merged = True return self - def forward(self, x: torch.Tensor): - def T(w): - return w.T if self.fan_in_fan_out else w +class LoraLinear(LoraBase): + """Replace in-place ops to out-of-place ops to fit gemini. 
Convert a torch.nn.Linear to LoraLinear.""" + + def __init__( + self, + weight: nn.Parameter, + bias: Union[nn.Parameter, bool], + r: int = 0, + lora_alpha: int = 32, + lora_dropout: float = 0.0, + lora_initialization_method: str = "kaiming_uniform", + ): + super().__init__( + r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method + ) + self.weight = weight + self.bias = bias + if bias is True: + self.bias = nn.Parameter(torch.zeros(weight.shape[0])) + if bias is not None: + self.bias.requires_grad = True + + out_features, in_features = weight.shape + self.in_features = in_features + self.out_features = out_features + assert lora_initialization_method in ["kaiming_uniform", "PiSSA"] + self.lora_initialization_method = lora_initialization_method + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.randn((r, in_features))) + self.lora_B = nn.Parameter(torch.randn((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + + def forward(self, x: torch.Tensor): if self.r > 0 and not self.merged: - result = F.linear(x, T(self.weight), bias=self.bias) - if self.r > 0: - result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling + result = F.linear(x, self.weight, bias=self.bias) + result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling return result else: - return F.linear(x, T(self.weight), bias=self.bias) + return F.linear(x, self.weight, bias=self.bias) -def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: +class LoraEmbedding(LoraBase): + """Replace in-place ops to out-of-place ops to fit gemini. 
Convert a torch.nn.Embedding to LoraEmbedding.""" + + def __init__( + self, + weight: nn.Parameter, + r: int = 0, + lora_alpha: int = 32, + lora_dropout: float = 0.1, + num_embeddings: int = None, + embedding_dim: int = None, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + lora_initialization_method: str = "kaiming_uniform", + ): + super().__init__( + r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method + ) + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + self.weight = weight + + in_features, out_features = num_embeddings, embedding_dim + self.in_features = in_features + self.out_features = out_features + assert lora_initialization_method in ["kaiming_uniform", "PiSSA"] + self.lora_initialization_method = lora_initialization_method + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.randn((r, in_features))) + self.lora_B = nn.Parameter(torch.randn((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + # reset parameters + nn.init.zeros_(self.lora_A) + nn.init.normal_(self.lora_B) + + def _embed(self, x: torch.Tensor, weight) -> torch.Tensor: + return F.embedding( + x, + weight, + padding_idx=self.padding_idx, + max_norm=self.max_norm, + norm_type=self.norm_type, + scale_grad_by_freq=self.scale_grad_by_freq, + sparse=self.sparse, + ) + + def forward(self, x: torch.Tensor): + base_embedding = self._embed(x, self.weight) + # base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing + if self.r > 0 and not self.merged: + lora_A_embedding = self._embed(x, self.lora_A.t()) + embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling + return embedding + else: + return base_embedding + + def train(self, mode: bool = True): + """ + This function runs when model.train() is invoked. It is used to prepare the embedding layer for training + """ + + self.training = mode + if mode and self.merged: + warnings.warn("Invoke module.train() would unmerge LoRA weights.") + raise NotImplementedError("LoRA unmerge is not tested.") + elif not mode and not self.merged and lora_manager.able_to_merge: + warnings.warn("Invoke module.eval() would merge LoRA weights.") + # Merge the weights and mark it + if self.r > 0: + self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling + delattr(self, "lora_A") + delattr(self, "lora_B") + self.merged = True + + return self + + +def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear: """ Wraps a linear layer with LoRA functionality. Args: linear (nn.Linear): The linear layer to be wrapped. lora_rank (int): The rank of the LoRA decomposition. + lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora". + lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA". Returns: LoraLinear: The wrapped linear layer with LoRA functionality.
""" assert ( - lora_rank <= linear.in_features - ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" - lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank) + lora_config.r <= linear.in_features + ), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})" + bias = None + if lora_config.lora_train_bias in ["all", "lora"]: + bias = linear.bias + if bias is None: + bias = True + lora_linear = LoraLinear( + linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method + ) return lora_linear -def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: +def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None: """ Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form. Args: module (nn.Module): The module to convert to LoRA form. lora_rank (int): The rank of the LoRA approximation. + lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora". + parent_name (str): The name of the parent module. + lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA". Returns: None """ for name, child in module.named_children(): if isinstance(child, nn.Linear): - setattr(module, name, _lora_linear_wrapper(child, lora_rank)) + if lora_config.target_modules is None or any( + [name in target_module for target_module in lora_config.target_modules] + ): + if dist.is_initialized() and dist.get_rank() == 0: + logger.info(f"Converting {parent_name}.{name} to LoRA") + setattr(module, name, _lora_linear_wrapper(child, lora_config)) + elif isinstance(child, nn.Embedding): + if lora_config.target_modules is None or any( + [name in target_module for target_module in lora_config.target_modules] + ): + if dist.is_initialized() and dist.get_rank() == 0: + logger.info(f"Converting {parent_name}.{name} to LoRA") + setattr( + module, + name, + LoraEmbedding( + child.weight, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.embedding_lora_dropout, + num_embeddings=child.num_embeddings, + embedding_dim=child.embedding_dim, + padding_idx=child.padding_idx, + max_norm=child.max_norm, + norm_type=child.norm_type, + scale_grad_by_freq=child.scale_grad_by_freq, + sparse=child.sparse, + lora_initialization_method=lora_config.lora_initialization_method, + ), + ) else: - _convert_to_lora_recursively(child, lora_rank) + _convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config) -def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module: +def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module: """Convert a torch.nn.Module to a LoRA module. Args: module (nn.Module): The module to convert. lora_rank (int): LoRA rank. + lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora". + lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA". Returns: nn.Module: The converted module. 
""" - if lora_rank <= 0: + if lora_config.r <= 0: return module - _convert_to_lora_recursively(module, lora_rank) - lora.mark_only_lora_as_trainable(module, lora_train_bias) + # make all parameter not trainable, if lora_train_bias is "all", set bias to trainable + total_parameter_size = 0 + for name, p in module.named_parameters(): + p.requires_grad = False + if "bias" in name and lora_config.lora_train_bias == "all": + p.requires_grad = True + total_parameter_size += p.numel() + _convert_to_lora_recursively(module, "", lora_config) + trainable_parameter_size = 0 + for name, p in module.named_parameters(): + if p.requires_grad == True: + trainable_parameter_size += p.numel() + if dist.is_initialized() and dist.get_rank() == 0: + logger.info( + f"Trainable parameter size: {trainable_parameter_size/1024/1024:.2f}M\nOriginal trainable parameter size: {total_parameter_size/1024/1024:.2f}M\nPercentage: {trainable_parameter_size/total_parameter_size*100:.2f}%" + ) return module diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py index e411dded1..840cca074 100755 --- a/applications/ColossalChat/coati/models/loss.py +++ b/applications/ColossalChat/coati/models/loss.py @@ -5,6 +5,7 @@ loss functions from typing import Optional, Tuple import torch +import torch.distributed as dist import torch.nn as nn from .utils import masked_mean @@ -89,11 +90,22 @@ class DpoLoss(nn.Module): """ Dpo loss Details: https://arxiv.org/pdf/2305.18290.pdf + + SimPO loss: + Details: https://arxiv.org/pdf/2405.14734.pdf """ - def __init__(self, beta: float = 0.1): + def __init__(self, beta: float = 0.1, gamma: float = 0.0): + """ + Args: + beta: The temperature parameter in the DPO paper. + gamma: The margin parameter in the SimPO paper. + length_normalization: Whether to normalize the loss by the length of chosen and rejected responses. + Refer to the length normalization in the SimPO paper + """ super().__init__() self.beta = beta + self.gamma = gamma def forward( self, @@ -104,7 +116,7 @@ class DpoLoss(nn.Module): chosen_mask: torch.Tensor, reject_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute the DPO loss for a batch of policy and reference model log probabilities. + """Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities. # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328 @@ -113,6 +125,8 @@ class DpoLoss(nn.Module): logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + chosen_mask: Mask tensor indicating which responses were chosen. Shape: (batch_size,) + reject_mask: Mask tensor indicating which responses were rejected. Shape: (batch_size,) Returns: A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 
@@ -127,13 +141,12 @@ class DpoLoss(nn.Module): if len(logprob_ref_chosen.shape) == 2: ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1) else: - ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze() + ref_logratios = logprob_ref_chosen - logprob_ref_reject else: # If no reference model is provided ref_logratios = 0.0 - pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1) - logits = pi_logratios - ref_logratios + logits = pi_logratios - ref_logratios - self.gamma / self.beta losses = -torch.nn.functional.logsigmoid(self.beta * logits) # Calculate rewards for logging @@ -168,3 +181,93 @@ class LogExpLoss(nn.Module): def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean() return loss + + +class OddsRatioLoss(nn.Module): + """ + Odds Ratio Loss in ORPO + Details: https://arxiv.org/pdf/2403.07691 + """ + + def forward( + self, + chosen_logp: torch.Tensor, + reject_logp: torch.Tensor, + chosen_loss_mask: torch.Tensor, + reject_loss_mask: torch.Tensor, + ) -> torch.Tensor: + chosen_logp = chosen_logp.to(dtype=torch.float32) + reject_logp = reject_logp.to(dtype=torch.float32) + chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001) + chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask) + reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001) + reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask) + log_odds_ratio = chosen_odds_masked - reject_odds_masked + ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio)) + return ratio.to(dtype=torch.bfloat16), log_odds_ratio + + +class KTOLoss(nn.Module): + def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0): + """ + Args: + beta: The temperature parameter in the KTO paper. + desirable_weight: The weight for the desirable responses. + undesirable_weight: The weight for the undesirable + """ + super().__init__() + self.beta = beta + self.desirable_weight = desirable_weight + self.undesirable_weight = undesirable_weight + + def forward( + self, + chosen_logps: torch.Tensor, + rejected_logps: torch.Tensor, + kl_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ref_kl_logps: torch.Tensor, + ): + """ + Reference: + https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585 + + Compute the KTO loss for a batch of policy and reference model log probabilities. + Args: + chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + kl_logps: KL divergence of the policy model. Shape: (batch_size,) + ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + ref_kl_logps: KL divergence of the reference model. Shape: (batch_size,) + beta: The temperature parameter in the DPO paper. + desirable_weight: The weight for the desirable responses. + undesirable_weight: The weight for the undesirable responses. 
+ + Refer to the KTO paper for details about hyperparameters https://arxiv.org/pdf/2402.01306 + """ + kl = (kl_logps - ref_kl_logps).mean().detach() + # all gather + dist.all_reduce(kl, op=dist.ReduceOp.SUM) + kl = (kl / dist.get_world_size()).clamp(min=0) + + if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0: + chosen_logratios = chosen_logps - ref_chosen_logps + chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl)) + chosen_rewards = self.beta * chosen_logratios.detach() + else: + chosen_losses = torch.Tensor([]).to(kl_logps.device) + chosen_rewards = torch.Tensor([]).to(kl_logps.device) + + if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0: + rejected_logratios = rejected_logps - ref_rejected_logps + rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios)) + rejected_rewards = self.beta * rejected_logratios.detach() + else: + rejected_losses = torch.Tensor([]).to(kl_logps.device) + rejected_rewards = torch.Tensor([]).to(kl_logps.device) + + losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean() + + return losses, chosen_rewards, rejected_rewards, kl diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py index ce672534c..8ed8d3401 100755 --- a/applications/ColossalChat/coati/models/utils.py +++ b/applications/ColossalChat/coati/models/utils.py @@ -89,7 +89,9 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch return mean -def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor: +def calc_masked_log_probs( + logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor, length_normalization: bool = False +) -> torch.Tensor: """ Calculate the masked log probabilities for a given sequence of logits. 
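The `KTOLoss` added above can be sanity-checked on a single process. The sketch below uses made-up log-probabilities and replaces the all-reduced, cross-rank KL baseline with a plain in-process mean; everything else follows the formula in the hunk.

```python
import torch

beta, desirable_weight, undesirable_weight = 0.1, 1.0, 1.0

# Illustrative per-sample sequence log-probabilities (policy vs. frozen reference)
chosen_logps, ref_chosen_logps = torch.tensor([-10.0, -8.5]), torch.tensor([-10.4, -8.9])
rejected_logps, ref_rejected_logps = torch.tensor([-11.2]), torch.tensor([-10.8])
kl_logps, ref_kl_logps = torch.tensor([-9.0, -9.5, -10.1]), torch.tensor([-9.2, -9.4, -10.0])

# Single-process stand-in for the all-reduced KL baseline
kl = (kl_logps - ref_kl_logps).mean().detach().clamp(min=0)

chosen_losses = 1 - torch.sigmoid(beta * ((chosen_logps - ref_chosen_logps) - kl))
rejected_losses = 1 - torch.sigmoid(beta * (kl - (rejected_logps - ref_rejected_logps)))
loss = torch.cat((desirable_weight * chosen_losses, undesirable_weight * rejected_losses)).mean()
print(loss.item())
```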
@@ -103,7 +105,11 @@ def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mas """ # logits are probabilities of the next token, so we shift them to the left by one log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) - return log_probs * mask + + if not length_normalization: + return log_probs * mask + else: + return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01) def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py index 2eff8ca76..6d0900153 100755 --- a/applications/ColossalChat/coati/trainer/__init__.py +++ b/applications/ColossalChat/coati/trainer/__init__.py @@ -1,7 +1,18 @@ from .base import OLTrainer, SLTrainer from .dpo import DPOTrainer +from .kto import KTOTrainer +from .orpo import ORPOTrainer from .ppo import PPOTrainer from .rm import RewardModelTrainer from .sft import SFTTrainer -__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"] +__all__ = [ + "SLTrainer", + "OLTrainer", + "RewardModelTrainer", + "SFTTrainer", + "PPOTrainer", + "DPOTrainer", + "ORPOTrainer", + "KTOTrainer", +] diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index cbe7d7ca8..c7ef2be8f 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -2,6 +2,7 @@ Dpo trainer """ +import os from typing import Any, Optional import torch @@ -25,7 +26,7 @@ from .utils import is_rank_0, to_device class DPOTrainer(SLTrainer): """ - Trainer for PPO algorithm. + Trainer for DPO algorithm. Args: actor (Actor): the actor model in ppo algorithm @@ -53,6 +54,8 @@ class DPOTrainer(SLTrainer): tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, beta: float = 0.1, + gamma: float = 0.0, + length_normalization: bool = False, accumulation_steps: int = 1, start_epoch: int = 0, save_interval: int = 0, @@ -63,7 +66,7 @@ class DPOTrainer(SLTrainer): self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer - self.actor_loss_fn = DpoLoss(beta) + self.actor_loss_fn = DpoLoss(beta, gamma) self.save_interval = save_interval self.coordinator = coordinator self.save_dir = save_dir @@ -71,6 +74,7 @@ class DPOTrainer(SLTrainer): self.accumulation_steps = accumulation_steps self.device = get_current_device() self.accumulative_meter = AccumulativeMeanMeter() + self.length_normalization = length_normalization def _before_fit( self, @@ -131,18 +135,21 @@ class DPOTrainer(SLTrainer): batch["reject_attention_mask"], batch["reject_loss_mask"], ) - reject_loss_mask[:, -1] = False batch_size = chosen_input_ids.size()[0] actor_all_logits = self.model( input_ids=torch.cat([chosen_input_ids, reject_input_ids]), attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), - )["logits"].to(torch.float32) + )["logits"] actor_chosen_logits = actor_all_logits[:batch_size] actor_reject_logits = actor_all_logits[batch_size:] - logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]) + logprob_actor_chosen = calc_masked_log_probs( + actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization + ) - logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]) + logprob_actor_reject = calc_masked_log_probs( + actor_reject_logits, 
reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization + ) if self.ref_model is not None: self.ref_model.eval() @@ -150,14 +157,14 @@ class DPOTrainer(SLTrainer): ref_all_logits = self.ref_model( input_ids=torch.cat([chosen_input_ids, reject_input_ids]), attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), - )["logits"].to(torch.float32) + )["logits"] ref_chosen_logits = ref_all_logits[:batch_size] ref_reject_logits = ref_all_logits[batch_size:] logprob_ref_chosen = calc_masked_log_probs( - ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:] + ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization ) logprob_ref_reject = calc_masked_log_probs( - ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:] + ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization ) else: logprob_ref_chosen = None @@ -219,7 +226,7 @@ class DPOTrainer(SLTrainer): ) self.accumulative_meter.reset() - if (self.num_train_step + 1) % self.save_interval == 0: + if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0: # save checkpoint self.coordinator.print_on_master("\nStart saving model checkpoint with running states") save_checkpoint( @@ -283,16 +290,16 @@ class DPOTrainer(SLTrainer): actor_all_logits = self.model( torch.cat([chosen_input_ids, reject_input_ids]), torch.cat([chosen_attention_mask, reject_attention_mask]), - )["logits"].to(torch.float32) + )["logits"] actor_chosen_logits = actor_all_logits[:batch_size] actor_reject_logits = actor_all_logits[batch_size:] logprob_actor_chosen = calc_masked_log_probs( - actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:] + actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization ) logprob_actor_reject = calc_masked_log_probs( - actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:] + actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization ) self.ref_model.eval() @@ -300,11 +307,15 @@ class DPOTrainer(SLTrainer): ref_all_logits = self.ref_model( torch.cat([chosen_input_ids, reject_input_ids]), torch.cat([chosen_attention_mask, reject_attention_mask]), - )["logits"].to(torch.float32) + )["logits"] ref_chosen_logits = ref_all_logits[:batch_size] ref_reject_logits = ref_all_logits[batch_size:] - logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]) - logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]) + logprob_ref_chosen = calc_masked_log_probs( + ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization + ) + logprob_ref_reject = calc_masked_log_probs( + ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization + ) losses, chosen_rewards, rejected_rewards = self.actor_loss_fn( logprob_actor_chosen, @@ -314,7 +325,7 @@ class DPOTrainer(SLTrainer): chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:], ) - reward_accuracies = (chosen_rewards > rejected_rewards).float() + reward_accuracies = (chosen_rewards > rejected_rewards).float().mean() loss = losses.mean() loss_mean = all_reduce_mean(tensor=loss) chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) @@ -333,4 +344,7 @@ class DPOTrainer(SLTrainer): for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]: msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" self.coordinator.print_on_master(msg) + 
os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py new file mode 100755 index 000000000..8ab0bc66b --- /dev/null +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -0,0 +1,318 @@ +""" +KTO trainer +""" + +import os +from typing import Any, Optional + +import torch +import torch.distributed +from coati.models.loss import KTOLoss +from coati.models.utils import calc_masked_log_probs +from coati.trainer.utils import all_reduce_mean +from coati.utils import AccumulativeMeanMeter, save_checkpoint +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from tqdm import trange +from transformers import PreTrainedTokenizerBase + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator +from colossalai.utils import get_current_device + +from .base import SLTrainer +from .utils import is_rank_0, to_device + + +class KTOTrainer(SLTrainer): + """ + Trainer for KTO algorithm. + + Args: + actor (Actor): the actor model in ppo algorithm + ref_model (Critic): the reference model in ppo algorithm + booster (Strategy): the strategy to use for training + actor_optim (Optimizer): the optimizer to use for actor model + actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model + tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding + max_epochs (int, defaults to 1): the max number of epochs to train + accumulation_steps (int): the number of steps to accumulate gradients + start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint + save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning + save_dir (str): the directory to save checkpoints + coordinator (DistCoordinator): the coordinator to use for distributed logging + beta (float, defaults to 0.1): the beta parameter in kto loss + desirable_weight (float, defaults to 1.0): the weight for desirable reward + undesirable_weight (float, defaults to 1.0): the weight for undesirable reward + """ + + def __init__( + self, + actor: Any, + ref_model: Any, + booster: Booster, + actor_optim: Optimizer, + actor_lr_scheduler: _LRScheduler, + tokenizer: PreTrainedTokenizerBase, + max_epochs: int = 1, + beta: float = 0.1, + desirable_weight: float = 1.0, + undesirable_weight: float = 1.0, + accumulation_steps: int = 1, + start_epoch: int = 0, + save_interval: int = 0, + save_dir: str = None, + coordinator: DistCoordinator = None, + ) -> None: + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + self.ref_model = ref_model + self.actor_scheduler = actor_lr_scheduler + self.tokenizer = tokenizer + self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight) + self.save_interval = save_interval + self.coordinator = coordinator + self.save_dir = save_dir + self.num_train_step = 0 + self.accumulation_steps = accumulation_steps + self.device = get_current_device() + self.accumulative_meter = AccumulativeMeanMeter() + self.desirable_weight = desirable_weight + self.undesirable_weight = undesirable_weight + self.beta = beta + + def _before_fit( + self, + train_preference_dataloader: DataLoader = None, + 
eval_preference_dataloader: DataLoader = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader (DataLoader): the dataloader to use for pretrain data + """ + self.train_dataloader = train_preference_dataloader + self.eval_dataloader = eval_preference_dataloader + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + self.wandb_run = wandb.init(project="Coati-kto", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "kto") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _train(self, epoch: int): + """ + Args: + epoch int: the number of current epoch + """ + self.model.train() + self.accumulative_meter.reset() + step_bar = trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, self.device) + (input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = ( + batch["input_ids"], + batch["attention_mask"], + batch["loss_mask"], + batch["label"], + batch["kl_input_ids"], + batch["kl_attention_mask"], + batch["kl_loss_mask"], + ) + batch_size = input_ids.size()[0] + + # actor logits + with torch.no_grad(): + # calculate KL term with KT data + kl_logits = self.model( + input_ids=kl_input_ids, + attention_mask=kl_attention_mask, + )["logits"] + + logits = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + )["logits"] + + logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1) + kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1) + chosen_index = [i for i in range(batch_size) if label[i] == 1] + rejected_index = [i for i in range(batch_size) if label[i] == 0] + chosen_logprob = logprob[chosen_index] + rejected_logprob = logprob[rejected_index] + with torch.no_grad(): + ref_kl_logits = self.ref_model( + input_ids=kl_input_ids, + attention_mask=kl_attention_mask, + )["logits"] + ref_logits = self.ref_model( + input_ids=input_ids, + attention_mask=attention_mask, + )["logits"] + + ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1) + ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1) + ref_chosen_logprob = ref_logprob[chosen_index] + ref_rejected_logprob = ref_logprob[rejected_index] + + loss, chosen_rewards, rejected_rewards, kl = self.kto_loss( + chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob + ) + + self.booster.backward(loss=loss, optimizer=self.optimizer) + if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1: + self.optimizer.step() + self.optimizer.zero_grad() + self.actor_scheduler.step() + + # sync + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean()) + rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean()) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + 
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item()) + + if i % self.accumulation_steps == self.accumulation_steps - 1: + self.num_train_step += 1 + step_bar.update() + # logging + if self.writer and is_rank_0(): + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar( + "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + ) + self.writer.add_scalar( + "train/rejected_rewards", + self.accumulative_meter.get("rejected_rewards"), + self.num_train_step, + ) + self.writer.add_scalar( + "train/margin", + self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), + self.num_train_step, + ) + self.accumulative_meter.reset() + + if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0: + # save checkpoint + self.coordinator.print_on_master("\nStart saving model checkpoint with running states") + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.actor_scheduler, + epoch=epoch, + step=i + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}" + ) + + step_bar.close() + + def _eval(self, epoch: int): + """ + Args: + epoch int: the number of current epoch + """ + if self.eval_dataloader is None: + self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") + return + self.model.eval() + self.accumulative_meter.reset() + step_bar = trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, self.device) + (input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = ( + batch["input_ids"], + batch["attention_mask"], + batch["loss_mask"], + batch["label"], + batch["kl_input_ids"], + batch["kl_attention_mask"], + batch["kl_loss_mask"], + ) + batch_size = input_ids.size()[0] + + # actor logits + with torch.no_grad(): + # calculate KL term with KT data + kl_logits = self.model( + input_ids=kl_input_ids, + attention_mask=kl_attention_mask, + )["logits"] + + logits = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + )["logits"] + + logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1) + kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1) + chosen_index = [i for i in range(batch_size) if label[i] == 1] + rejected_index = [i for i in range(batch_size) if label[i] == 0] + chosen_logprob = logprob[chosen_index] + rejected_logprob = logprob[rejected_index] + with torch.no_grad(): + ref_kl_logits = self.ref_model( + input_ids=kl_input_ids, + attention_mask=kl_attention_mask, + )["logits"] + + ref_logits = self.ref_model( + input_ids=input_ids, + attention_mask=attention_mask, + )["logits"] + + ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1) + ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1) + ref_chosen_logprob = 
ref_logprob[chosen_index] + ref_rejected_logprob = ref_logprob[rejected_index] + + loss, chosen_rewards, rejected_rewards, kl = self.kto_loss( + chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob + ) + + # sync + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean()) + rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean()) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item()) + self.accumulative_meter.add( + "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item() + ) + step_bar.update() + msg = "Evaluation Result:\n" + for tag in ["loss", "chosen_rewards", "rejected_rewards", "margin"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + self.coordinator.print_on_master(msg) + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py new file mode 100644 index 000000000..b039da4af --- /dev/null +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -0,0 +1,314 @@ +""" +Orpo trainer +""" + +import os +from typing import Any, Optional + +import torch +from coati.models.loss import OddsRatioLoss +from coati.models.utils import calc_masked_log_probs +from coati.trainer.utils import all_reduce_mean +from coati.utils import AccumulativeMeanMeter, save_checkpoint +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from tqdm import trange +from transformers import PreTrainedTokenizerBase + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator +from colossalai.utils import get_current_device + +from .base import SLTrainer +from .utils import is_rank_0, to_device + + +class ORPOTrainer(SLTrainer): + """ + Trainer for ORPO algorithm. 
+ + Args: + actor (Actor): the actor model in ppo algorithm + booster (Strategy): the strategy to use for training + actor_optim (Optimizer): the optimizer to use for actor model + actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model + tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding + max_epochs (int, defaults to 1): the max number of epochs to train + lam (float, defaults to 0.1): the lambda parameter in ORPO loss + accumulation_steps (int): the number of steps to accumulate gradients + start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint + save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning + save_dir (str): the directory to save checkpoints + coordinator (DistCoordinator): the coordinator to use for distributed logging + """ + + def __init__( + self, + actor: Any, + booster: Booster, + actor_optim: Optimizer, + actor_lr_scheduler: _LRScheduler, + tokenizer: PreTrainedTokenizerBase, + max_epochs: int = 1, + lam: float = 0.1, + accumulation_steps: int = 1, + start_epoch: int = 0, + save_interval: int = 0, + save_dir: str = None, + coordinator: DistCoordinator = None, + ) -> None: + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + self.actor_scheduler = actor_lr_scheduler + self.tokenizer = tokenizer + self.odds_ratio_loss_fn = OddsRatioLoss() + self.save_interval = save_interval + self.coordinator = coordinator + self.save_dir = save_dir + self.num_train_step = 0 + self.lam = lam + self.accumulation_steps = accumulation_steps + self.device = get_current_device() + self.accumulative_meter = AccumulativeMeanMeter() + + def _before_fit( + self, + train_preference_dataloader: DataLoader = None, + eval_preference_dataloader: DataLoader = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader (DataLoader): the dataloader to use for pretrain data + """ + self.train_dataloader = train_preference_dataloader + self.eval_dataloader = eval_preference_dataloader + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + self.wandb_run = wandb.init(project="Coati-orpo", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "orpo") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _train(self, epoch: int): + """ + Args: + epoch int: the number of current epoch + """ + self.model.train() + self.accumulative_meter.reset() + step_bar = trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, self.device) + ( + chosen_input_ids, + chosen_attention_mask, + chosen_loss_mask, + reject_input_ids, + reject_attention_mask, + reject_loss_mask, + ) = ( + batch["chosen_input_ids"], + batch["chosen_attention_mask"], + batch["chosen_loss_mask"], + batch["reject_input_ids"], + batch["reject_attention_mask"], + batch["reject_loss_mask"], + ) + batch_size = chosen_input_ids.size()[0] + actor_out = self.model( + 
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), + attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), + labels=torch.cat( + [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100] + ), + ) + torch.autograd.set_detect_anomaly(True) + actor_all_logits = actor_out["logits"].to(torch.float32) + actor_chosen_logits = actor_all_logits[:batch_size] + actor_reject_logits = actor_all_logits[batch_size:] + logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]) + + logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]) + # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100 + chosen_nll = actor_out["loss"] + odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn( + logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:] + ) + loss = chosen_nll - odds_ratio_loss * self.lam + step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}") + + self.booster.backward(loss=loss, optimizer=self.optimizer) + if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1: + self.optimizer.step() + self.optimizer.zero_grad() + self.actor_scheduler.step() + + chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:]) + rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:]) + reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0) + + # sync + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) + rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) + reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item()) + self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) + + if i % self.accumulation_steps == self.accumulation_steps - 1: + self.num_train_step += 1 + step_bar.update() + # logging + if self.writer and is_rank_0(): + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar( + "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + ) + self.writer.add_scalar( + "train/rejected_rewards", + self.accumulative_meter.get("rejected_rewards"), + self.num_train_step, + ) + self.writer.add_scalar( + "train/margin", + self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), + self.num_train_step, + ) + self.writer.add_scalar( + "train/accuracy", + self.accumulative_meter.get("accuracy"), + self.num_train_step, + ) + self.writer.add_scalar( + "train/log_odds_ratio", + self.accumulative_meter.get("log_odds_ratio"), + self.num_train_step, + ) + self.accumulative_meter.reset() + + if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0: + # save checkpoint + self.coordinator.print_on_master("\nStart saving model checkpoint with 
running states") + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.actor_scheduler, + epoch=epoch, + step=i + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}" + ) + + step_bar.close() + + def _eval(self, epoch: int): + """ + Args: + epoch int: the number of current epoch + """ + if self.eval_dataloader is None: + self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") + return + self.model.eval() + self.coordinator.print_on_master("\nStart evaluation...") + + step_bar = trange( + len(self.eval_dataloader), + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + + self.accumulative_meter.reset() + + with torch.no_grad(): + for i, batch in enumerate(self.eval_dataloader): + batch = to_device(batch, self.device) + ( + chosen_input_ids, + chosen_attention_mask, + chosen_loss_mask, + reject_input_ids, + reject_attention_mask, + reject_loss_mask, + ) = ( + batch["chosen_input_ids"], + batch["chosen_attention_mask"], + batch["chosen_loss_mask"], + batch["reject_input_ids"], + batch["reject_attention_mask"], + batch["reject_loss_mask"], + ) + batch_size = chosen_input_ids.size()[0] + actor_out = self.model( + input_ids=torch.cat([chosen_input_ids, reject_input_ids]), + attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), + labels=torch.cat( + [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100] + ), + ) + torch.autograd.set_detect_anomaly(True) + actor_all_logits = actor_out["logits"].to(torch.float32) + actor_chosen_logits = actor_all_logits[:batch_size] + actor_reject_logits = actor_all_logits[batch_size:] + logprob_actor_chosen = calc_masked_log_probs( + actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:] + ) + + logprob_actor_reject = calc_masked_log_probs( + actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:] + ) + chosen_nll = actor_out["loss"] + odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn( + logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:] + ) + loss = chosen_nll - odds_ratio_loss * self.lam + step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}") + + chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:]) + rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:]) + reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0) + + # sync + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) + rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) + reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item()) + self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) + + msg = "Evaluation Result:\n" + for tag in ["loss", "chosen_rewards", "rejected_rewards", 
"log_odds_ratio", "accuracy"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + self.coordinator.print_on_master(msg) + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index 0fb714a62..b9e84ef55 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -237,6 +237,7 @@ class RewardModelTrainer(SLTrainer): + f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n" ) self.coordinator.print_on_master(msg) + os.makedirs(self.save_dir, exist_ok=True) with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: f.write(msg) step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index c95f5b65a..c09d61034 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -102,6 +102,7 @@ class SFTTrainer(SLTrainer): batch_size = batch["input_ids"].size(0) outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) loss = outputs.loss + self.booster.backward(loss=loss, optimizer=self.optimizer) loss_mean = all_reduce_mean(tensor=loss) @@ -113,6 +114,7 @@ class SFTTrainer(SLTrainer): self.optimizer.zero_grad() self.scheduler.step() + step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) if self.writer: self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) @@ -165,6 +167,7 @@ class SFTTrainer(SLTrainer): for tag in ["loss"]: msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" self.coordinator.print_on_master(msg) + os.makedirs(self.save_dir, exist_ok=True) with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: f.write(msg) step_bar.close() diff --git a/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json new file mode 100644 index 000000000..58941a591 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json @@ -0,0 +1,9 @@ +{ + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 151645, + 151643 + ], + "end_of_assistant": "<|im_end|>" +} diff --git a/applications/ColossalChat/config/conversation_template/tiny-llama.json b/applications/ColossalChat/config/conversation_template/tiny-llama.json new file mode 100644 index 000000000..59196159f --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/tiny-llama.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ], + "end_of_assistant": "" +} diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md index a29fc7508..b749f197e 100755 --- a/applications/ColossalChat/examples/README.md +++ b/applications/ColossalChat/examples/README.md @@ -9,6 +9,7 @@ - [Install Requirements](#install-requirements) - [Get Start with ColossalRun](#get-start-with-colossalrun) - [Training Configuration](#training-configuration) + - [Parameter Efficient Finetuning (PEFT)](#parameter-efficient-finetuning-peft) - [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning) - [Step 1: Data Collection](#step-1-data-collection) - [Step 2: Preprocessing](#step-2-preprocessing) @@ -29,6 +30,9 @@ - [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization) - [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning) - [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training) + - [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization) + - [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto) + - [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization) - [List of Supported Models](#list-of-supported-models) - [Hardware Requirements](#hardware-requirements) - [Inference example](#inference-example) @@ -45,9 +49,6 @@ pip install -r requirements.txt ``` - - - ## Get Start with ColossalRun @@ -81,8 +82,6 @@ Make sure the master node can access all nodes (including itself) by ssh without This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins). - -
Gemini (Zero3) @@ -374,35 +373,6 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
-
Low Rank Adaption - - -Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources. - - -To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64). -``` -colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \ - --pretrain $PRETRAINED_MODEL_PATH \ - --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ - --dataset ${dataset[@]} \ - --save_interval 5000 \ - --save_path $SAVE_DIR \ - --config_file $CONFIG_FILE \ - --plugin zero2_cpu \ - --batch_size 4 \ - --max_epochs 1 \ - --accumulation_steps 4 \ - --lr 2e-5 \ - --max_len 2048 \ - --lora_rank 32 \ # This enables LoRA - --use_wandb -``` - - -
- -
Other Training Arguments @@ -427,6 +397,60 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai - use_wandb: if this flag is up, you can view logs on wandb. +
+ +### Parameter Efficient Finetuning (PEFT) + +Currently, we support LoRA (low-rank adaptation) and PiSSA (principal singular values and singular vectors adaptation). Both reduce run-time VRAM consumption as well as training time, at the cost of some overall model performance. + + +
Low Rank Adaptation and PiSSA + + +Details about Low Rank Adaptation (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). Details about Principal Singular Values and Singular Vectors Adaptation (PiSSA) can be found in the paper: [PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models](https://arxiv.org/abs/2404.02948). Both reduce run-time VRAM consumption as well as training time, at the cost of some overall model performance. They are suitable for training LLMs with constrained resources. + +To use LoRA/PiSSA in training, please create a config file as in the following example and set `--lora_config` to the path of that configuration file. + +```json +{ + "r": 128, + "embedding_lora_dropout": 0.0, + "linear_lora_dropout": 0.1, + "lora_alpha": 32, + "lora_train_bias": "all", + "lora_initialization_method": "PiSSA", + "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"] +} +``` +#### LoRA Parameters +- r: the LoRA rank +- embedding_lora_dropout: dropout probability for the embedding layer +- linear_lora_dropout: dropout probability for linear layers +- lora_alpha: the LoRA alpha, which controls how much the adapter can deviate from the pretrained model. +- lora_train_bias: whether to add trainable biases to LoRA layers; choose from "all" (all layers, including but not limited to LoRA layers, will have trainable biases), "none" (no trainable biases), "lora" (only LoRA layers will have trainable biases) +- lora_initialization_method: how to initialize the LoRA weights; choose one from ["kaiming_uniform", "PiSSA"], defaults to "kaiming_uniform". Use "kaiming_uniform" for standard LoRA and "PiSSA" for PiSSA. +- target_modules: which module(s) should be converted to LoRA layers; if a module's name contains one of the keywords in target_modules and the module is a linear or embedding layer, it will be converted. Otherwise, the module will be frozen. Setting this field to None will automatically convert all linear and embedding layers to their LoRA counterparts. Note that this example only works for LLaMA; for other models, you need to modify it. + + +``` +colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --save_interval 5000 \ + --save_path $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --plugin zero2_cpu \ + --batch_size 4 \ + --max_epochs 1 \ + --accumulation_steps 4 \ + --lr 2e-5 \ + --max_len 2048 \ + --lora_config /PATH/TO/THE/LORA/CONFIG/FILE.json \ # Setting this enables LoRA + --use_wandb +``` + +
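For reference, here is a minimal sketch of how such a config file is consumed, mirroring the pattern used by the updated training scripts in this patch (`LoraConfig.from_file` and `convert_to_lora_module` come from `coati.models`; the model path below is a placeholder):

```python
# Minimal sketch: load a LoRA/PiSSA config file and wrap a causal LM with it,
# mirroring the wiring used by the updated training scripts in this patch.
import torch
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from transformers import AutoModelForCausalLM

lora_config = LoraConfig.from_file("/PATH/TO/THE/LORA/CONFIG/FILE.json")
model = AutoModelForCausalLM.from_pretrained("/PATH/TO/PRETRAINED/MODEL")  # placeholder path

if lora_config is not None and lora_config.r > 0:
    # Replace the targeted linear/embedding modules with LoRA counterparts.
    model = convert_to_lora_module(model, lora_config=lora_config)
    # Keep norm and gating layers in fp32 for numerical stability, as the training scripts do.
    for name, module in model.named_modules():
        if "norm" in name or "gate" in name:
            module = module.to(torch.float32)
disable_dropout(model)
```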
@@ -445,7 +469,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the {"messages": [ { - "from": "human", + "from": "user", "content": "what are some pranks with a pen i can do?" }, { @@ -470,9 +494,15 @@ In this code we provide a flexible way for users to set the conversation templat - Step 1: (Optional). Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields. ```json { - "chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating, - "system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added, - "end_of_assistant": The token(s) in string that denotes the end of assistance's response. For example, in the ChatGLM2 prompt format, + "chat_template": "A string of chat_template used for formatting chat data", + "system_message": "A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added", + "end_of_assistant": "The token(s) in string that denotes the end of assistance's response", + "stop_ids": "A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training" + } + ``` + * `chat_template`: (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating. + * `system_message`: A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added. + * `end_of_assistant`: The token(s) in string that denotes the end of assistance's response". For example, in the ChatGLM2 prompt format, ``` <|im_start|>system system messages @@ -481,15 +511,13 @@ In this code we provide a flexible way for users to set the conversation templat <|im_start|>user How far is the moon? <|im_end|> <|im_start|>assistant\n The moon is about 384,400 kilometers away from Earth.<|im_end|>... - ``` - the end_of_assistant tokens are "<|im_end|>" - "stop_ids": (Optional), A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. 
If not set, will set to tokenizer.eos_token_ids automatically - } - ``` - On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message), + ``` + the `end_of_assistant` tokens are "<|im_end|>" + * `stop_ids`: (Optional), A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically. + On your first run of the data preparation script, you only need to define the `chat_template` (if you want to use custom chat template) and the `system message` (if you want to use a custom system message) -- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path. +- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path. - Step 3: (Optional) Check the correctness of the processed data. We provided an easy way for you to do a manual checking on the processed data by checking the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files. @@ -509,7 +537,7 @@ Human: what are some pranks with a pen i can do? Assistant: Are you #### Step 3: Training -Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. +Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. ### RLHF Training Stage2 - Training Reward Model @@ -526,7 +554,7 @@ Below shows the preference dataset format used in training the reward model. [ {"context": [ { - "from": "human", + "from": "user", "content": "Introduce butterflies species in Oregon." } ] @@ -551,11 +579,11 @@ Below shows the preference dataset format used in training the reward model. #### Step 2: Preprocessing -Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training. 
+Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training. #### Step 3: Training -You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. +You can run [train_rm.sh](./training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. #### Features and Tricks in RM Training @@ -595,7 +623,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi #### Step 1: Data Collection -PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format. +PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "user" and thus the "assistant" needs to generate a response to answer to the "user". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format. ```json @@ -603,7 +631,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op {"messages": [ { - "from": "human", + "from": "user", "content": "what are some pranks with a pen i can do?" } ... @@ -626,14 +654,14 @@ The second dataset--- pretrained dataset is optional, provide it if you want to ] ``` #### Step 2: Preprocessing -To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh) +To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./data_preparation_scripts/prepare_prompt_dataset.sh) You can use the SFT dataset you prepared in the SFT stage or prepare a new one from different source for the ptx dataset. The ptx data is used to calculate ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf). #### Step 3: Training -You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. +You can run the [train_ppo.sh](./training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. 
```bash @@ -717,17 +745,90 @@ For DPO training, you only need the preference dataset. Please follow the instru #### Step 2: Training -You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. +You can run the [train_dpo.sh](./training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added option for the user to choose from, including whether to do length normalization , reward shaping and whether to use a reference model in calculating implicit reward. Here are those options, +``` +--beta 0.1 \ # the temperature in DPO loss, Default to 0.1 +--gamma 0.0 \ # the reward target margin in the SimPO paper, Default to 0. +--disable_reference_model \ # whether to disable the reference model, if set, the implicit reward will be calculated solely from the actor. Default to enable reference model in DPO +--length_normalization \ # whether to apply length normalization, Default to not use +``` #### DPO Result

image

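To make the `--beta`, `--gamma`, `--length_normalization` and `--disable_reference_model` options described above concrete, here is an illustrative PyTorch sketch of a DPO-style loss. It is not the repository's implementation; with length normalization enabled, a positive `gamma` and no reference model it reduces to the SimPO objective discussed next.

```python
# Illustrative sketch of a DPO-style loss and how the flags above interact.
import torch
import torch.nn.functional as F

def dpo_style_loss(
    logp_chosen, logp_rejected,          # summed log-probs of the policy on chosen/rejected responses
    ref_logp_chosen, ref_logp_rejected,  # summed log-probs of the reference model, or None
    chosen_len, rejected_len,            # response lengths, used only for length normalization
    beta=0.1, gamma=0.0, length_normalization=False,
):
    if length_normalization:
        logp_chosen, logp_rejected = logp_chosen / chosen_len, logp_rejected / rejected_len
        if ref_logp_chosen is not None:
            ref_logp_chosen = ref_logp_chosen / chosen_len
            ref_logp_rejected = ref_logp_rejected / rejected_len
    if ref_logp_chosen is None:
        # --disable_reference_model: the implicit reward comes from the actor alone.
        ref_logp_chosen = torch.zeros_like(logp_chosen)
        ref_logp_rejected = torch.zeros_like(logp_rejected)
    chosen_reward = beta * (logp_chosen - ref_logp_chosen)
    rejected_reward = beta * (logp_rejected - ref_logp_rejected)
    # gamma is the reward target margin; gamma=0 recovers plain DPO.
    return -F.logsigmoid(chosen_reward - rejected_reward - gamma).mean()
```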
+### Alternative Option For RLHF: Simple Preference Optimization + +We support the method introduced in the paper [SimPO: Simple Preference Optimization +with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO for alignment, use the [train_dpo.sh](./training_scripts/train_dpo.sh) script and set `loss_type` to `simpo_loss`; you can also set the temperature (`beta`) and the reward target margin (`gamma`), but both are optional. + +#### SimPO Result +

+image +

+ + +### Alternative Option For RLHF: Odds Ratio Preference Optimization +We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), a reference-model-free alignment method that mixes the SFT loss with a reinforcement learning loss using the odds ratio as the implicit reward, to enhance training stability and efficiency. To use ORPO for alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script. You can set the value of `lambda` (which determines how strongly the reinforcement learning loss affects the training), but it is optional. + +#### ORPO Result +

+image +

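As a rough sketch of the objective behind this option (the actual implementation is the ORPO trainer added by this patch), the training loss is the SFT negative log-likelihood on the chosen response minus a `lambda`-weighted log-sigmoid of the log-odds ratio; the helper names below are illustrative:

```python
# Illustrative sketch of the ORPO objective, assuming per-token average
# log-probabilities for the chosen and rejected responses.
import torch
import torch.nn.functional as F

def log_odds(avg_logp: torch.Tensor) -> torch.Tensor:
    # log( p / (1 - p) ) computed from the average per-token log-probability.
    p = torch.exp(avg_logp).clamp(max=1 - 1e-6)
    return avg_logp - torch.log1p(-p)

def orpo_loss(chosen_nll, avg_logp_chosen, avg_logp_rejected, lam=0.1):
    log_odds_ratio = log_odds(avg_logp_chosen) - log_odds(avg_logp_rejected)
    # The log-sigmoid of the log-odds ratio acts as a reward to maximize, so it is
    # subtracted from the SFT negative log-likelihood, weighted by lambda.
    return chosen_nll - lam * F.logsigmoid(log_odds_ratio).mean()
```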
+ +### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO) +We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. + +For KTO data preparation, please use the script [prepare_kto_dataset.sh](./data_preparation_scripts/prepare_kto_dataset.sh). You will need preference data; unlike DPO and its derivatives, you no longer need a pair of chosen/rejected responses for the same input. You only need data whose responses are associated with a preference label--- whether the response is desirable or not; read the paper for more details. You also need to convert your data to the following intermediate format before you run the data preparation script. + +```jsonl +{ + "prompt": [ + { + "from": "user", + "content": "What are some praise words in english?" + }, + { + "from": "assistant", + "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..." + }, + { + "from": "user", + "content": "What's your favorite one?" + } + ], + "completion": { + "from": "assistant", + "content": "impressive." + }, + "label": true +} + +``` + +For training, use the [train_kto.sh](./training_scripts/train_kto.sh) script. You may need to set the value of `beta` (which determines how strongly the reinforcement learning loss affects the training), as well as `desirable_weight` and `undesirable_weight` if your data is biased (has an unequal number of desirable and undesirable samples). + +#### KTO Result +

+image +

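One practical detail: `train_kto.py` in this patch checks that the weighted ratio of desirable to undesirable samples falls within the theoretical bounds [1, 4/3] from Eq. (8) of the KTO paper. Below is a standalone sketch of that check, with placeholder counts:

```python
# Illustrative sketch of the desirable/undesirable weight check performed by
# train_kto.py in this patch. The sample counts below are placeholders.
num_desirable, num_undesirable = 7000, 3000      # samples with label == True vs. label == False
desirable_weight, undesirable_weight = 1.0, 1.0  # values passed via --desirable_weight / --undesirable_weight

ratio = (desirable_weight * num_desirable) / (undesirable_weight * num_undesirable)
if not (1.0 <= ratio <= 4.0 / 3.0):
    # Either rescale one of the weights (what --auto_weight does) or adjust them manually.
    desirable_weight = desirable_weight / ratio
    print(f"ratio {ratio:.2f} is outside [1, 4/3]; rescaled desirable_weight to {desirable_weight:.3f}")
```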
## Hardware Requirements -For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use an H800 GPU with 80GB VRAM. + +For SFT, we recommend using zero2 or zero2-cpu for a 7B model, and tensor parallelism (tp) if your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention. +- 2 H800 GPU + - zero2-cpu, micro batch size=4, VRAM Usage=22457.98 MB + - zero2, micro batch size=4, VRAM Usage=72390.95 MB +- 4 H800 GPUs + - zero2_cpu, micro batch size=8, VRAM Usage=19412.77 MB + - zero2, micro batch size=8, VRAM Usage=43446.31 MB + - zero2, micro batch size=16, VRAM Usage=58082.30 MB + - zero2, micro batch size=8, lora_rank=8, VRAM Usage=21167.73 MB + - zero2, micro batch size=8, lora_rank=32, VRAM Usage=21344.17 MB + +For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model (llama2-7B-hf) on a dummy dataset with a sequence length of 2048 and a layout length of 512 with different tp_size (equal to the number of GPUs). | PPO | tp=8 | tp=4 | |-------|---------------|---------------| | bs=1 | 18485.19 MB | 42934.45 MB | @@ -738,12 +839,39 @@ For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length. - -- 1 H800 GPU - - zero2-cpu, batch size=2, VRAM Usage=49873.90 MB - - zero2-cpu, batch size=4, VRAM Usage=60998.22 MB +- 2 H800 GPU + - zero2-cpu, micro batch size=2, VRAM Usage=36989.37 MB + - zero2-cpu, micro batch size=4, VRAM Usage=48081.67 MB - 4 H800 GPUs - - zero2, batch size=4, VRAM Usage=67544.47 MB + - zero2, micro batch size=4, VRAM Usage=67483.44 MB + +For SimPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length. + +- 2 H800 GPU + - zero2-cpu, micro batch size=4, VRAM Usage=25705.26 MB + - zero2, micro batch size=4, VRAM Usage=73375.04 MB +- 4 H800 GPUs + - zero2_cpu, micro batch size=8, VRAM Usage=36709.36 MB + - zero2, micro batch size=4, VRAM Usage=44330.90 MB + - zero2, micro batch size=8, VRAM Usage=56086.12 MB + +For ORPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length. + +- 2 H800 GPU + - zero2-cpu, micro batch size=4, VRAM Usage=26693.38 MB + - zero2, micro batch size=4, VRAM Usage=74332.65 MB +- 4 H800 GPUs + - zero2_cpu, micro batch size=8, VRAM Usage=38709.73 MB + - zero2, micro batch size=4, VRAM Usage=45309.52 MB + - zero2, micro batch size=8, VRAM Usage=58086.37 MB + +For KTO, we recommend using the zero2-cpu or zero2 plugin. We tested the VRAM consumption on a dummy dataset with 2048 sequence length. 
+- 2 H800 GPU + - zero2-cpu, micro batch size=2, VRAM Usage=35241.98 MB + - zero2-cpu, micro batch size=4, VRAM Usage=38989.37 MB +- 4 H800 GPUs + - zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB + - zero2, micro batch size=4, VRAM_USAGE=59307.97 MB ## List of Supported Models diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py index 64093f88d..a35f2bf52 100644 --- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py @@ -40,7 +40,7 @@ import random import time from multiprocessing import cpu_count -from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf +from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft from datasets import dataset_dict, load_dataset from transformers import AutoTokenizer @@ -56,8 +56,8 @@ def main(): type=str, required=True, default=None, - choices=["sft", "prompt", "preference"], - help="Type of dataset, chose from 'sft', 'prompt', 'preference'.", + choices=["sft", "prompt", "preference", "kto"], + help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'", ) parser.add_argument( "--data_input_dirs", @@ -199,11 +199,13 @@ def main(): ) if args.type == "sft": - preparation_function = supervised_tokenize_sft + preparation_function = tokenize_sft elif args.type == "prompt": - preparation_function = tokenize_prompt_dataset + preparation_function = tokenize_prompt elif args.type == "preference": preparation_function = tokenize_rlhf + elif args.type == "kto": + preparation_function = tokenize_kto else: raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']") @@ -228,10 +230,13 @@ def main(): keep_in_memory=False, num_proc=min(len(dataset), cpu_count()), ) - - dataset = dataset.filter( - lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None - ) + if args.type == "kto": + filter_by = "completion" + elif args.type == "preference": + filter_by = "chosen_input_ids" + else: + filter_by = "input_ids" + dataset = dataset.filter(lambda data: data[filter_by] is not None) # Save each jsonl spliced dataset. 
output_index = "0" * (5 - len(str(index))) + str(index) diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh new file mode 100755 index 000000000..42c785289 --- /dev/null +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh @@ -0,0 +1,14 @@ +SAVE_DIR="" + +rm -rf $SAVE_DIR/cache +rm -rf $SAVE_DIR/jsonl +rm -rf $SAVE_DIR/arrow + +python prepare_dataset.py --type kto \ + --data_input_dirs /PATH/TO/KTO/DATASET \ + --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ + --tokenizer_dir "" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow \ + --max_length 1024 diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh index 999d7778b..5c06b43fe 100755 --- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh @@ -5,9 +5,10 @@ rm -rf $SAVE_DIR/jsonl rm -rf $SAVE_DIR/arrow python prepare_dataset.py --type preference \ - --data_input_dirs "PATH/TO/PREFERENCE/DATA" \ + --data_input_dirs /PATH/TO/PREFERENCE/DATASET \ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ --tokenizer_dir "" \ --data_cache_dir $SAVE_DIR/cache \ --data_jsonl_output_dir $SAVE_DIR/jsonl \ - --data_arrow_output_dir $SAVE_DIR/arrow + --data_arrow_output_dir $SAVE_DIR/arrow \ + --max_length 1024 diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh index 8d3d6c2c2..d74667889 100755 --- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh @@ -10,4 +10,5 @@ python prepare_dataset.py --type prompt \ --tokenizer_dir "" \ --data_cache_dir $SAVE_DIR/cache \ --data_jsonl_output_dir $SAVE_DIR/jsonl \ - --data_arrow_output_dir $SAVE_DIR/arrow + --data_arrow_output_dir $SAVE_DIR/arrow \ + --max_length 1024 diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh index 8562b47ee..84bae0027 100755 --- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh @@ -5,9 +5,10 @@ rm -rf $SAVE_DIR/jsonl rm -rf $SAVE_DIR/arrow python prepare_dataset.py --type sft \ - --data_input_dirs "PATH/TO/SFT/DATA" \ + --data_input_dirs /PATH/TO/SFT/DATASET \ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ --tokenizer_dir "" \ --data_cache_dir $SAVE_DIR/cache \ --data_jsonl_output_dir $SAVE_DIR/jsonl \ --data_arrow_output_dir $SAVE_DIR/arrow \ + --max_length 4096 diff --git a/applications/ColossalChat/examples/inference/round.txt b/applications/ColossalChat/examples/inference/round.txt new file mode 100644 index 000000000..ba02074c1 --- /dev/null +++ b/applications/ColossalChat/examples/inference/round.txt @@ -0,0 +1,104 @@ + + +========== +round 1: +[INST] <> +A chat between a curious human and an artificial 
intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +tell me a story [/INST] Great, let’s hear a story. + +========== + + +========== +round 2: +[INST] <> +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 + +========== + + +========== +round 3: +[INST] <> +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 [INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. + +========== + + +========== +round 1: +[INST] <> +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +who is the first president of the USA? [/INST] The first president of the United States was George Washington. + +========== + + +========== +round 2: +[INST] <> +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +who is the first president of the USA? [/INST] The first president of the United States was George Washington. [INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. + +========== + + +========== +round 1: +[INST] <> +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? + +========== + + +========== +round 2: +[INST] <> +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? + +========== + + +========== +round 3: +[INST] <> +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + + +<> + +tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. 
I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? [INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. + +========== diff --git a/applications/ColossalChat/examples/requirements.txt b/applications/ColossalChat/examples/requirements.txt index 838590f4b..91f25a5cf 100644 --- a/applications/ColossalChat/examples/requirements.txt +++ b/applications/ColossalChat/examples/requirements.txt @@ -1,4 +1,4 @@ pandas>=1.4.1 sentencepiece -colossalai +colossalai==0.4.0 prompt_toolkit diff --git a/applications/ColossalChat/examples/training_scripts/hostfile b/applications/ColossalChat/examples/training_scripts/hostfile index c7aed75a3..2fbb50c4a 100755 --- a/applications/ColossalChat/examples/training_scripts/hostfile +++ b/applications/ColossalChat/examples/training_scripts/hostfile @@ -1,5 +1 @@ -XXX.XX.XXX.XXX # Your master IP -XXX.XX.XXX.XXX # Your slave IPs -XXX.XX.XXX.XXX # Your slave IPs -XXX.XX.XXX.XXX # Your slave IPs -XXX.XX.XXX.XXX # Your slave IPs +localhost diff --git a/applications/ColossalChat/examples/training_scripts/lora_config.json b/applications/ColossalChat/examples/training_scripts/lora_config.json new file mode 100644 index 000000000..4565f9e9b --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/lora_config.json @@ -0,0 +1,9 @@ +{ + "r": 128, + "embedding_lora_dropout": 0.0, + "linear_lora_dropout": 0.1, + "lora_alpha": 32, + "lora_train_bias": "all", + "lora_initialization_method": "PiSSA", + "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"] +} diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index a5b4cb3bd..44131f572 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -6,7 +6,7 @@ from contextlib import nullcontext import torch from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset -from coati.models import convert_to_lora_module, disable_dropout +from coati.models import LoraConfig, convert_to_lora_module, disable_dropout from coati.trainer import DPOTrainer from coati.utils import load_checkpoint from transformers import AutoModelForCausalLM, AutoTokenizer @@ -23,8 +23,11 @@ logger = get_dist_logger() def train(args): + lora_config = None + if args.lora_config is not None: + lora_config = LoraConfig.from_file(args.lora_config) # check lora compatibility - if "gemini" in args.plugin and args.lora_rank > 0: + if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") if args.plugin == "gemini_auto" and args.accumulation_steps > 1: raise ValueError("Gradient accumulation is not supported in GeminiPlugin. 
Please use other plugin") @@ -115,8 +118,8 @@ def train(args): coordinator.print_on_master(msg="Flash-attention enabled successfully") else: model = AutoModelForCausalLM.from_pretrained(args.pretrain) - disable_dropout(model) - if args.enable_reference_model: + + if not args.disable_reference_model: if args.use_flash_attn: ref_model = AutoModelForCausalLM.from_pretrained( args.pretrain, @@ -125,18 +128,20 @@ def train(args): ) else: ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain) - disable_dropout(ref_model) else: ref_model = None + if args.lora_config is not None: + model = convert_to_lora_module(model, lora_config=lora_config) + for name, module in model.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) + disable_dropout(model) + disable_dropout(ref_model) - if args.lora_rank > 0: - model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) - - if args.grad_checkpoint and args.lora_rank == 0: - model.gradient_checkpointing_enable() + if args.grad_checkpoint: + # Note, for some models, lora may not be compatible with gradient checkpointing + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - elif args.lora_rank > 0: - coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") # configure tokenizer tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain @@ -178,6 +183,21 @@ def train(args): collate_fn=data_collator, distributed_sampler_cls=StatefulDistributedSampler, ) + eval_dataloader = None + if args.eval_dataset: + eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev") + eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length) + + eval_dataloader = plugin.prepare_dataloader( + dataset=eval_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=eval_data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + logger.warning("No evaluation dataset is provided, skip evaluation") num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps if args.warmup_steps is None: @@ -255,25 +275,28 @@ def train(args): save_interval=args.save_interval, save_dir=args.save_dir, coordinator=coordinator, + beta=args.beta, + gamma=args.gamma, + length_normalization=args.length_normalization, ) trainer.fit( train_preference_dataloader=train_dataloader, - eval_preference_dataloader=None, + eval_preference_dataloader=eval_dataloader, log_dir=args.log_dir, use_wandb=args.use_wandb, ) - if args.lora_rank > 0 and args.merge_lora_weights: - from coati.models.lora import LORA_MANAGER - + if lora_config is not None and lora_config.r > 0: # NOTE: set model to eval to merge LoRA weights - LORA_MANAGER.merge_weights = True model.eval() # save model checkpoint after fitting on only rank0 - coordinator.print_on_master("Start saving final model checkpoint") - booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}") + if args.save_dir is not None: + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final model checkpoint at epoch 
{args.max_epochs} at folder {args.save_dir}" + ) coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") @@ -296,6 +319,10 @@ if __name__ == "__main__": parser.add_argument("--tp", type=int, default=1) parser.add_argument("--pp", type=int, default=1) parser.add_argument("--sp", type=int, default=1) + parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss") + parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss") + parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss") + parser.add_argument("--length_normalization", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_cpu_offload", default=False, action="store_true") @@ -304,33 +331,39 @@ if __name__ == "__main__": parser.add_argument("--model_type", type=str, default=None) parser.add_argument("--tokenizer_dir", type=str, default=None) parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument("--eval_dataset", nargs="+", default=[]) parser.add_argument( "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint" ) - parser.add_argument("--config_file", type=str, default="config_file", help="Config file") - parser.add_argument("--save_dir", type=str, default="output") + parser.add_argument("--config_file", type=str, default=None, help="Config file") + parser.add_argument("--save_dir", type=str, default=None) parser.add_argument("--max_length", type=int, default=2048, help="Model max length") parser.add_argument("--max_epochs", type=int, default=3) parser.add_argument("--batch_size", type=int, default=4) - parser.add_argument("--enable_reference_model", type=bool, default=True) - parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument( - "--lora_train_bias", - type=str, - default="none", - help="'none' means it doesn't train biases. 'all' means it trains all biases. 
'lora_only' means it only trains biases of LoRA layers", + "--disable_reference_model", + action="store_true", + default=False, + help="Disable the reference model (enabled by default)", ) + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") - parser.add_argument("--merge_lora_weights", type=bool, default=True) parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--accumulation_steps", type=int, default=8) - parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true") args = parser.parse_args() - os.makedirs(os.path.dirname(args.config_file), exist_ok=True) - with open(args.config_file, "w") as f: - json.dump(args.__dict__, f, indent=4) + + # fool proof hyperparameter setup + if args.loss_type == "simpo_loss": + args.length_normalization = True + args.gamma = args.gamma if args.gamma > 0 else 1.4 + + if args.config_file is not None: + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.sh b/applications/ColossalChat/examples/training_scripts/train_dpo.sh index 80fc30c3d..4d49bc218 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.sh +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.sh @@ -13,50 +13,52 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { echo "Now CUDA_VISIBLE_DEVICES is set to:" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" } -set_n_least_used_CUDA_VISIBLE_DEVICES 8 -# export CUDA_VISIBLE_DEVICES=6 +set_n_least_used_CUDA_VISIBLE_DEVICES 4 -PROJECT_NAME="dpo" +PROJECT_NAME="DPO" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PARENT_LOG_DIR="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path declare -a dataset=( - YOUR/DATA/DIR/arrow/part-00000 - YOUR/DATA/DIR/arrow/part-00001 - YOUR/DATA/DIR/arrow/part-00002 - YOUR/DATA/DIR/arrow/part-00003 - YOUR/DATA/DIR/arrow/part-00004 - YOUR/DATA/DIR/arrow/part-00005 - YOUR/DATA/DIR/arrow/part-00006 - YOUR/DATA/DIR/arrow/part-00007 - YOUR/DATA/DIR/arrow/part-00008 - YOUR/DATA/DIR/arrow/part-00009 + /Your/Preference/Data/arrow/part-00000 + /Your/Preference/Data/arrow/part-00001 + /Your/Preference/Data/arrow/part-00002 + /Your/Preference/Data/arrow/part-00003 + /Your/Preference/Data/arrow/part-00004 + /Your/Preference/Data/arrow/part-00005 + /Your/Preference/Data/arrow/part-00006 + /Your/Preference/Data/arrow/part-00007 + /Your/Preference/Data/arrow/part-00008 + /Your/Preference/Data/arrow/part-00009 ) TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" 
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" +LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}" -colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_dpo.py \ +colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \ --pretrain $PRETRAINED_MODEL_PATH \ - --checkpoint_path $PRETRAINED_MODEL_PATH \ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ --save_interval 1000 \ --save_dir $SAVE_DIR \ --config_file $CONFIG_FILE \ + --log_dir $LOG_DIR \ --max_epochs 1 \ - --accumulation_steps 4 \ - --batch_size 2 \ + --accumulation_steps 2 \ + --batch_size 16 \ --lr 1e-6 \ + --beta 0.1 \ --mixed_precision "bf16" \ --grad_clip 1.0 \ + --max_length 4096 \ --weight_decay 0.01 \ - --warmup_steps 100 \ + --warmup_steps 60 \ --grad_checkpoint \ --use_wandb diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py new file mode 100755 index 000000000..d063b82bb --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_kto.py @@ -0,0 +1,376 @@ +import argparse +import json +import os +import resource +from contextlib import nullcontext + +import torch +from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset +from coati.models import LoraConfig, convert_to_lora_module, disable_dropout +from coati.trainer import KTOTrainer +from coati.utils import load_checkpoint +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + +logger = get_dist_logger() + + +def train(args): + lora_config = None + if args.lora_config is not None: + lora_config = LoraConfig.from_file(args.lora_config) + # check lora compatibility + if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: + raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") + if args.plugin == "gemini_auto" and args.accumulation_steps > 1: + raise ValueError("Gradient accumulation is not supported in GeminiPlugin. 
Please use other plugin") + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + """ + Default torch ddp plugin without any acceleration, for + debugging purpose acceleration, for debugging purpose + """ + plugin = TorchDDPPlugin(find_unused_parameters=True) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="static", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=True, + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + enable_sequence_parallelism=args.enable_sequence_parallelism, + cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, + parallel_output=False, + max_norm=args.grad_clip, + precision=args.mixed_precision, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + ref_booster = Booster(plugin=plugin) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # Temp Fix: Disable lazy init due to version conflict + # init_ctx = ( + # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + # ) + + init_ctx = nullcontext() + with init_ctx: + if args.use_flash_attn: + model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + else: + model = AutoModelForCausalLM.from_pretrained(args.pretrain) + + if args.use_flash_attn: + ref_model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + else: + ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain) + if args.lora_config is not None: + model = convert_to_lora_module(model, lora_config=lora_config) + for name, module in model.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) + disable_dropout(ref_model) + disable_dropout(model) + + if args.grad_checkpoint: + # Note, for some models, lora may not be compatible with gradient checkpointing + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + coordinator.print_on_master(msg="Gradient checkpointing enabled 
successfully") + + # configure tokenizer + tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True) + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them." + ) + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + # configure optimizer + optim = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # configure dataset + coordinator.print_on_master(f"Load dataset: {args.dataset}") + mode_map = {"train": "train", "valid": "validation", "test": "test"} + train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map) + num_desirable = 0 + num_undesirable = 0 + for i in range(len(train_dataset)): + if train_dataset[i]["label"]: + num_desirable += 1 + else: + num_undesirable += 1 + logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}") + + # Check if the user specified weights fit into the theoratical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306 + actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable) + if actual_ratio < 1 or actual_ratio > 4 / 3: + if not args.auto_weight: + raise AssertionError( + f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase/decrease desirable weight or decrease/increase undesirable weight." + ) + else: + args.desirable_weight = args.desirable_weight / actual_ratio + coordinator.print_on_master( + f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. 
Actual ratio: {actual_ratio}, auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}" + ) + + data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length) + + train_dataloader = plugin.prepare_dataloader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + eval_dataloader = None + if args.eval_dataset: + eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev") + eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length) + + eval_dataloader = plugin.prepare_dataloader( + dataset=eval_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=eval_data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + logger.warning("No evaluation dataset is provided, skip evaluation") + + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + if args.warmup_steps is None: + args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optim, + total_steps=args.max_epochs * num_update_steps_per_epoch, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optim, _, train_dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + dataloader=train_dataloader, + ) + if ref_model is not None: + ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader) + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + sampler_start_idx = 0 + start_step = 0 + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}") + booster.load_model(model, args.checkpoint_path) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.checkpoint_path, + booster=booster, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + ) + assert isinstance(train_dataloader.sampler, StatefulDistributedSampler) + train_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + trainer = KTOTrainer( + actor=model, + 
ref_model=ref_model, + booster=booster, + actor_optim=optim, + actor_lr_scheduler=lr_scheduler, + tokenizer=tokenizer, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + start_epoch=start_epoch, + save_interval=args.save_interval, + save_dir=args.save_dir, + coordinator=coordinator, + beta=args.beta, + desirable_weight=args.desirable_weight, + undesirable_weight=args.undesirable_weight, + ) + + trainer.fit( + train_preference_dataloader=train_dataloader, + eval_preference_dataloader=eval_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if lora_config is not None and lora_config.r > 0: + # NOTE: set model to eval to merge LoRA weights + model.eval() + # save model checkpoint after fitting on only rank0 + if args.save_dir is not None: + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}" + ) + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--sp", type=int, default=1) + parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss") + parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss") + parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss") + parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") + parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) + parser.add_argument("--zero_cpu_offload", default=False, action="store_true") + parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument("--eval_dataset", nargs="+", default=[]) + parser.add_argument( + "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint" + ) + parser.add_argument("--config_file", type=str, default=None, help="Config file") + parser.add_argument("--save_dir", type=str, default=None) + parser.add_argument("--max_length", type=int, default=2048, help="Model max length") + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--lora_config", type=str, default=None, help="low-rank 
adaptation config file path") + parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") + parser.add_argument("--auto_weight", default=False, action="store_true") + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--log_dir", default=None, type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + args = parser.parse_args() + if args.config_file is not None: + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.sh b/applications/ColossalChat/examples/training_scripts/train_kto.sh new file mode 100755 index 000000000..c28338c22 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_kto.sh @@ -0,0 +1,65 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +PROJECT_NAME="kto" +PARENT_SAVE_DIR="" # Path to a folder to save checkpoints +PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs +PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PARENT_LOG_DIR="" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path + +declare -a dataset=( + /Your/KTO/Data/arrow/part-00000 + /Your/KTO/Data/arrow/part-00001 + /Your/KTO/Data/arrow/part-00002 + /Your/KTO/Data/arrow/part-00003 + /Your/KTO/Data/arrow/part-00004 + /Your/KTO/Data/arrow/part-00005 + /Your/KTO/Data/arrow/part-00006 + /Your/KTO/Data/arrow/part-00007 + /Your/KTO/Data/arrow/part-00008 + /Your/KTO/Data/arrow/part-00009 +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}" + +colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2" \ + --save_interval 1000 \ + --save_dir $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --log_dir $LOG_DIR \ + --max_epochs 1 \ + --accumulation_steps 1 \ + --batch_size 8 \ + --auto_weight \ + --lr 1e-5 \ + --beta 0.1 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --max_length 1024 \ + --weight_decay 0.01 \ + --warmup_steps 60 \ + --grad_checkpoint diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py new file mode 100755 index 000000000..f06524507 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py @@ -0,0 +1,341 @@ +import argparse +import json 
+import os +import resource +from contextlib import nullcontext + +import torch +from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset +from coati.models import LoraConfig, convert_to_lora_module, disable_dropout +from coati.trainer import ORPOTrainer +from coati.utils import load_checkpoint +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + +logger = get_dist_logger() + + +def train(args): + lora_config = None + if args.lora_config is not None: + lora_config = LoraConfig.from_file(args.lora_config) + # check lora compatibility + if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: + raise ValueError("LoRA is not supported in GeminiPlugin. Please use another plugin") + if args.plugin == "gemini_auto" and args.accumulation_steps > 1: + raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use another plugin") + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + """ + Default torch ddp plugin without any acceleration, for debugging purpose + """ + plugin = TorchDDPPlugin(find_unused_parameters=True) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="static", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=True, + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + enable_sequence_parallelism=args.enable_sequence_parallelism, + cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, + parallel_output=False, + max_norm=args.grad_clip, + precision=args.mixed_precision, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # Temp Fix: Disable lazy init due to version conflict + # init_ctx = ( + # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + # ) + +
init_ctx = nullcontext() + with init_ctx: + if args.use_flash_attn: + model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + else: + model = AutoModelForCausalLM.from_pretrained(args.pretrain) + if args.lora_config is not None: + model = convert_to_lora_module(model, lora_config=lora_config) + for name, module in model.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) + disable_dropout(model) + + if args.grad_checkpoint: + # Note: for some models, LoRA may not be compatible with gradient checkpointing + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + + # configure tokenizer + tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True) + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers don't allow setting pad_token manually, e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually." + ) + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + # configure optimizer + optim = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # configure dataset + coordinator.print_on_master(f"Load dataset: {args.dataset}") + mode_map = {"train": "train", "valid": "validation", "test": "test"} + train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map) + data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length) + + train_dataloader = plugin.prepare_dataloader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + + eval_dataloader = None + if args.eval_dataset: + eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev") + eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length) + eval_dataloader = plugin.prepare_dataloader( + dataset=eval_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=eval_data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + logger.warning("No evaluation dataset is provided, skipping evaluation") + + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + if args.warmup_steps is None: + args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps set to {args.warmup_steps}") + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optim, + total_steps=args.max_epochs * num_update_steps_per_epoch, + warmup_steps=args.warmup_steps, + eta_min=0.1 *
args.lr, + ) + + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optim, _, train_dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + dataloader=train_dataloader, + ) + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + sampler_start_idx = 0 + start_step = 0 + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}") + booster.load_model(model, args.checkpoint_path) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.checkpoint_path, + booster=booster, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + ) + assert isinstance(train_dataloader.sampler, StatefulDistributedSampler) + train_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + trainer = ORPOTrainer( + actor=model, + booster=booster, + actor_optim=optim, + actor_lr_scheduler=lr_scheduler, + tokenizer=tokenizer, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + start_epoch=start_epoch, + save_interval=args.save_interval, + save_dir=args.save_dir, + coordinator=coordinator, + lam=args.lam, + ) + + trainer.fit( + train_preference_dataloader=train_dataloader, + eval_preference_dataloader=eval_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if lora_config is not None and lora_config.r > 0: + # NOTE: set model to eval to merge LoRA weights + model.eval() + # save model checkpoint after fitting on only rank0 + if args.save_dir is not None: + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}" + ) + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + 
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--sp", type=int, default=1) + parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss") + parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") + parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) + parser.add_argument("--zero_cpu_offload", default=False, action="store_true") + parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument("--eval_dataset", nargs="+", default=[]) + parser.add_argument( + "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint" + ) + parser.add_argument("--config_file", type=str, default=None, help="Config file") + parser.add_argument("--save_dir", type=str, default=None) + parser.add_argument("--max_length", type=int, default=2048, help="Model max length") + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument( + "--disable_reference_model", + action="store_true", + default=False, + help="Disable the reference model (enabled by default)", + ) + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") + parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--log_dir", default=None, type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + args = parser.parse_args() + if args.config_file is not None: + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.sh b/applications/ColossalChat/examples/training_scripts/train_orpo.sh new file mode 100755 index 000000000..48327e014 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.sh @@ -0,0 +1,64 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +PROJECT_NAME="ORPO" +PARENT_SAVE_DIR="" # Path to a folder to save checkpoints +PARENT_CONFIG_FILE="" # Path to a folder to save 
training config logs +PARENT_LOG_DIR="" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path + +declare -a dataset=( + /Your/Preference/Data/arrow/part-00000 + /Your/Preference/Data/arrow/part-00001 + /Your/Preference/Data/arrow/part-00002 + /Your/Preference/Data/arrow/part-00003 + /Your/Preference/Data/arrow/part-00004 + /Your/Preference/Data/arrow/part-00005 + /Your/Preference/Data/arrow/part-00006 + /Your/Preference/Data/arrow/part-00007 + /Your/Preference/Data/arrow/part-00008 + /Your/Preference/Data/arrow/part-00009 +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" +LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}" + +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2" \ + --save_interval 1000 \ + --save_dir $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --log_dir $LOG_DIR \ + --max_epochs 3 \ + --accumulation_steps 1 \ + --batch_size 16 \ + --lr 8e-6 \ + --lam 0.5 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --max_length 1024 \ + --weight_decay 0.01 \ + --warmup_steps 60 \ + --grad_checkpoint \ + --use_wandb diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py index 3da3e9ca6..333be9963 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.py +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py @@ -13,7 +13,7 @@ from coati.dataset import ( load_tokenized_dataset, setup_conversation_template, ) -from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout +from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager from coati.trainer import PPOTrainer from coati.utils import load_checkpoint from transformers import AutoModelForCausalLM, AutoTokenizer @@ -31,8 +31,11 @@ logger = get_dist_logger() def train(args): + lora_config = None + if args.lora_config is not None: + lora_config = LoraConfig.from_file(args.lora_config) # check lora compatibility - if "gemini" in args.plugin and args.lora_rank > 0: + if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") if args.plugin == "gemini_auto" and args.accumulation_steps > 1: raise ValueError("Gradient accumulation is not supported in GeminiPlugin. 
Please use other plugin") @@ -81,20 +84,26 @@ def train(args): ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True) reward_model = RewardModel(args.rm_pretrain) critic = Critic(args.rm_pretrain) + + if args.lora_config is not None: + actor = convert_to_lora_module(actor, lora_config=lora_config) + critic = convert_to_lora_module(critic, lora_config=lora_config) + for name, module in actor.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) + for name, module in critic.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) + lora_manager.able_to_merge = False + # Disable dropout disable_dropout(actor) disable_dropout(critic) - if args.lora_rank > 0: - actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias) - critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias) - - if args.grad_checkpoint and args.lora_rank == 0: - actor.gradient_checkpointing_enable() - critic.model.gradient_checkpointing_enable() + if args.grad_checkpoint: + actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - elif args.lora_rank > 0: - coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") # configure tokenizer tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain @@ -421,11 +430,9 @@ def train(args): use_wandb=args.use_wandb, ) - if args.lora_rank > 0 and args.merge_lora_weights: - from coati.models.lora import LORA_MANAGER - + if lora_config is not None and lora_config.r > 0: # NOTE: set model to eval to merge LoRA weights - LORA_MANAGER.merge_weights = True + lora_manager.able_to_merge = True actor.eval() critic.eval() # save model checkpoint after fitting on only rank0 @@ -484,11 +491,9 @@ if __name__ == "__main__": parser.add_argument("--train_batch_size", type=int, default=16) parser.add_argument("--experience_batch_size", type=int, default=16) parser.add_argument("--ptx_batch_size", type=int, default=4) - parser.add_argument("--lora_train_bias", type=str, default="none") + parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--accumulation_steps", type=int, default=8) - parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument("--merge_lora_weights", type=bool, default=True) parser.add_argument("--lr", type=float, default=9e-6) parser.add_argument("--critic_lr", type=float, default=9e-6) parser.add_argument("--kl_coef", type=float, default=0.1) diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.sh b/applications/ColossalChat/examples/training_scripts/train_ppo.sh index 91633978e..277e75e6d 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.sh +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.sh @@ -15,10 +15,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { } set_n_least_used_CUDA_VISIBLE_DEVICES 8 -PROJECT_NAME="ppo" +PROJECT_NAME="PPO" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to 
a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT) PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path @@ -54,7 +53,7 @@ declare -a ptx_dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \ --pretrain $PRETRAINED_MODEL_PATH \ diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py index ce0d02b5d..4c0a782b4 100755 --- a/applications/ColossalChat/examples/training_scripts/train_rm.py +++ b/applications/ColossalChat/examples/training_scripts/train_rm.py @@ -7,7 +7,7 @@ from contextlib import nullcontext import torch from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset -from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module +from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module from coati.trainer import RewardModelTrainer from coati.utils import load_checkpoint from transformers import AutoTokenizer @@ -16,14 +16,20 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer.policies.auto_policy import get_autopolicy +logger = get_dist_logger() + def train(args): + lora_config = None + if args.lora_config is not None: + lora_config = LoraConfig.from_file(args.lora_config) # check lora compatibility - if "gemini" in args.plugin and args.lora_rank > 0: + if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") if args.plugin == "gemini_auto" and args.accumulation_steps > 1: raise ValueError("Gradient accumulation is not supported in GeminiPlugin. 
Please use other plugin") @@ -55,9 +61,11 @@ def train(args): args.pretrain, ) - if args.lora_rank > 0: - model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) - + if lora_config is not None: + model = convert_to_lora_module(model, lora_config=lora_config) + for name, module in model.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) # ============================== # Initialize Booster # ============================== @@ -119,11 +127,9 @@ def train(args): booster = Booster(plugin=plugin) - if args.grad_checkpoint and args.lora_rank == 0: - model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer + if args.grad_checkpoint: + model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - elif args.lora_rank > 0: - coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") # configure tokenizer tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain @@ -173,6 +179,22 @@ def train(args): collate_fn=data_collator, distributed_sampler_cls=StatefulDistributedSampler, ) + + eval_dataloader = None + if args.eval_dataset: + eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev") + eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length) + eval_dataloader = plugin.prepare_dataloader( + dataset=eval_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=eval_data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + logger.warning("No evaluation dataset is provided, skip evaluation") + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps math.ceil(args.max_epochs * num_update_steps_per_epoch) @@ -253,21 +275,21 @@ def train(args): trainer.fit( train_preference_dataloader=train_dataloader, - eval_preference_dataloader=None, + eval_preference_dataloader=eval_dataloader, log_dir=args.log_dir, use_wandb=args.use_wandb, ) - if args.lora_rank > 0 and args.merge_lora_weights: - from coati.models.lora import LORA_MANAGER - + if lora_config is not None and lora_config.r > 0: # NOTE: set model to eval to merge LoRA weights - LORA_MANAGER.merge_weights = True model.eval() # save model checkpoint after fitting on only rank0 - coordinator.print_on_master("Start saving final model checkpoint") - booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}") + if args.save_dir is not None: + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}" + ) coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") @@ -297,33 +319,28 @@ if __name__ == "__main__": parser.add_argument("--pretrain", type=str, default=None) parser.add_argument("--tokenizer_dir", type=str, default=None) parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument("--eval_dataset", nargs="+", default=[]) parser.add_argument( "--checkpoint_path", type=str, default=None, help="Checkpoint 
path if need to resume training form a checkpoint" ) - parser.add_argument("--config_file", type=str, default="config_file", help="Config file") - parser.add_argument("--save_dir", type=str, default="output") + parser.add_argument("--config_file", type=str, default=None, help="Config file") + parser.add_argument("--save_dir", type=str, default=None) parser.add_argument("--max_length", type=int, default=2048, help="Model max length") parser.add_argument("--max_epochs", type=int, default=3) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function") - parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument( - "--lora_train_bias", - type=str, - default="none", - help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers", - ) + parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") - parser.add_argument("--merge_lora_weights", type=bool, default=True) parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--accumulation_steps", type=int, default=8) - parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true") args = parser.parse_args() - os.makedirs(os.path.dirname(args.config_file), exist_ok=True) - with open(args.config_file, "w") as f: - json.dump(args.__dict__, f, indent=4) + if args.config_file is not None: + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.sh b/applications/ColossalChat/examples/training_scripts/train_rm.sh index e06d9092f..274417c03 100755 --- a/applications/ColossalChat/examples/training_scripts/train_rm.sh +++ b/applications/ColossalChat/examples/training_scripts/train_rm.sh @@ -15,10 +15,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { } set_n_least_used_CUDA_VISIBLE_DEVICES 8 -PROJECT_NAME="rm" +PROJECT_NAME="RM" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PARENT_LOG_DIR="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path @@ -38,17 +38,18 @@ declare -a dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" +LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}" colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \ --pretrain $PRETRAINED_MODEL_PATH \ - --checkpoint_path 
/home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ --save_interval 1000 \ --save_dir $SAVE_DIR \ --config_file $CONFIG_FILE \ + --log_dir $LOG_DIR \ --max_epochs 3 \ --accumulation_steps 1 \ --batch_size 8 \ diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index 08e7550df..6007a8599 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -7,7 +7,7 @@ from contextlib import nullcontext import torch from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset -from coati.models import convert_to_lora_module +from coati.models import LoraConfig, convert_to_lora_module from coati.trainer import SFTTrainer from coati.utils import load_checkpoint from transformers import AutoModelForCausalLM, AutoTokenizer @@ -24,8 +24,11 @@ logger = get_dist_logger() def train(args): + lora_config = None + if args.lora_config is not None: + lora_config = LoraConfig.from_file(args.lora_config) # check lora compatibility - if "gemini" in args.plugin and args.lora_rank > 0: + if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") if args.plugin == "gemini_auto" and args.accumulation_steps > 1: raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") @@ -53,15 +56,19 @@ def train(args): torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, trust_remote_code=True, ) - if args.lora_rank > 0: - model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) + + if lora_config is not None: + model = convert_to_lora_module(model, lora_config=lora_config) + for name, module in model.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) if args.plugin == "ddp": """ Default torch ddp plugin without any acceleration, for debugging purpose acceleration, for debugging purpose """ - plugin = TorchDDPPlugin(find_unused_parameters=True) + plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False) elif args.plugin == "gemini": plugin = GeminiPlugin( precision=args.mixed_precision, @@ -114,6 +121,15 @@ def train(args): booster = Booster(plugin=plugin) + # configure optimizer + optim = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== @@ -122,12 +138,10 @@ def train(args): # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() # ) - if args.grad_checkpoint and args.lora_rank == 0: - # lora layers are not supported by gradient checkpointing - model.gradient_checkpointing_enable() + if args.grad_checkpoint: + # Note, for some models, lora may not be compatible with gradient checkpointing + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - 
elif args.lora_rank > 0: - coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") # configure tokenizer tokenizer = AutoTokenizer.from_pretrained( @@ -151,15 +165,6 @@ def train(args): coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}") - # configure optimizer - optim = HybridAdam( - model_params=model.parameters(), - lr=args.lr, - betas=(0.9, 0.95), - weight_decay=args.weight_decay, - adamw_mode=True, - ) - # configure dataset coordinator.print_on_master( f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" @@ -175,6 +180,23 @@ def train(args): collate_fn=data_collator, distributed_sampler_cls=StatefulDistributedSampler, ) + + eval_dataloader = None + if args.eval_dataset: + eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev") + eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len) + + eval_dataloader = plugin.prepare_dataloader( + dataset=eval_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=eval_data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + logger.warning("No evaluation dataset is provided, skip evaluation") + coordinator.print_on_master( f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" ) @@ -202,6 +224,7 @@ def train(args): lr_scheduler=lr_scheduler, dataloader=train_dataloader, ) + torch.set_default_dtype(torch.float) coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") @@ -257,22 +280,21 @@ def train(args): trainer.fit( train_dataloader=train_dataloader, - eval_dataloader=None, + eval_dataloader=eval_dataloader, log_dir=args.log_dir, use_wandb=args.use_wandb, ) - if args.lora_rank > 0 and args.merge_lora_weights: - from coati.models.lora import LORA_MANAGER - + if lora_config is not None and lora_config.r > 0: # NOTE: set model to eval to merge LoRA weights - LORA_MANAGER.merge_weights = True model.eval() # save model checkpoint after fitting on only rank0 - coordinator.print_on_master("Start saving final model checkpoint") - - # booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True) - coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}") + if args.save_path is not None: + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}" + ) coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") @@ -302,32 +324,27 @@ if __name__ == "__main__": parser.add_argument("--pretrain", type=str, default=None) parser.add_argument("--tokenizer_dir", type=str, default=None) parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument("--eval_dataset", nargs="+", default=[]) parser.add_argument( "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint" ) - parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--save_path", type=str, default=None) parser.add_argument("--max_epochs", type=int, default=3) 
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument( - "--lora_train_bias", - type=str, - default="none", - help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers", - ) + parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") - parser.add_argument("--merge_lora_weights", type=bool, default=True) parser.add_argument("--lr", type=float, default=5e-6) - parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + parser.add_argument("--config_file", type=str, default=None, help="Config file") parser.add_argument("--accumulation_steps", type=int, default=8) - parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true") args = parser.parse_args() - os.makedirs(os.path.dirname(args.config_file), exist_ok=True) - with open(args.config_file, "w") as f: - json.dump(args.__dict__, f, indent=4) + if args.config_file is not None: + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh index 53c712901..e87184c81 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.sh +++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh @@ -13,13 +13,11 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" } - -# export CUDA_VISIBLE_DEVICES=4,5,6 -set_n_least_used_CUDA_VISIBLE_DEVICES 2 -PROJECT_NAME="sft" +set_n_least_used_CUDA_VISIBLE_DEVICES 4 +PROJECT_NAME="SFT" PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PARENT_LOG_DIR="" # Path to a folder to save training config logs PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path declare -a dataset=( @@ -38,28 +36,25 @@ declare -a dataset=( TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" -CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" +LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}" echo $(which colossalai) echo $(which python) # the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size -colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \ +colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \ --pretrain 
$PRETRAINED_MODEL_PATH \ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ - --save_interval 4000 \ + --save_interval 2000 \ --dataset ${dataset[@]} \ - --save_path $SAVE_DIR \ - --config_file $CONFIG_FILE \ - --lora_rank 0 \ - --plugin 3d \ - --tp 2 \ - --pp 1 \ - --zero_stage 0 \ - --batch_size 2 \ - --max_epochs 3 \ + --plugin zero2 \ + --batch_size 8 \ + --max_epochs 1 \ --accumulation_steps 1 \ --lr 5e-5 \ - --max_len 400 \ + --max_len 4096 \ + --use_flash_attn \ --grad_checkpoint \ - --use_wandb \ - --use_flash_attn + --save_path $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --log_dir $LOG_DIR \ diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt index ef3a5a0e8..2188de12f 100755 --- a/applications/ColossalChat/requirements.txt +++ b/applications/ColossalChat/requirements.txt @@ -1,9 +1,9 @@ -transformers>=4.36.2 +transformers==4.39.3 tqdm datasets==2.14.7 loralib -colossalai>=0.3.7 -torch>=1.12.1 +colossalai==0.4.0 +torch>=2.1.0 langchain tokenizers fastapi diff --git a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py index 9f85b4beb..e50b20b6b 100644 --- a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py +++ b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py @@ -4,7 +4,7 @@ import os sft_seed = { "messages": [ - {"from": "human", "content": "Give three tips for staying healthy."}, + {"from": "user", "content": "Give three tips for staying healthy."}, { "from": "assistant", "content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.", @@ -13,7 +13,7 @@ sft_seed = { } prompt_seed = { "messages": [ - {"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."}, + {"from": "user", "content": "Describe the impacts of climate change on communities living in coastal areas."}, { "from": "assistant", "content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.", @@ -22,21 +22,34 @@ prompt_seed = { } preference_seed = { "context": [ - {"from": "human", "content": "What kind of noises did dinosaurs make?"}, + {"from": "user", "content": "What kind of noises did dinosaurs make?"}, { "from": "assistant", "content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. 
The best place to find out what noises dinosaurs made would be", }, - {"from": "human", "content": "yes they did"}, + {"from": "user", "content": "yes they did"}, { "from": "assistant", "content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.", }, - {"from": "human", "content": "you cant read"}, + {"from": "user", "content": "you cant read"}, ], "chosen": [{"from": "assistant", "content": "You can read?"}], "rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}], } +kto_seed = { + "prompt": [ + {"from": "user", "content": "What are some praise words in english?"}, + { + "from": "assistant", + "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ...", + }, + {"from": "user", "content": "What's your favorite one?"}, + ], + "completion": {"from": "assistant", "content": "Impressive."}, + "label": True, +} + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -61,12 +74,21 @@ if __name__ == "__main__": seed = prompt_seed elif args.data_type == "preference": seed = preference_seed + elif args.data_type == "kto": + seed = kto_seed else: raise ValueError(f"Unknown data type {args.data_type}") - - line = json.dumps(seed, ensure_ascii=False) + "\n" - for idx in [1, 2, 3]: - with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f: - for i in range(1000): + if args.data_type != "kto": + line = json.dumps(seed, ensure_ascii=False) + "\n" + for idx in [1, 2, 3]: + with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f: + for i in range(1000): + f.write(line) f.write(line) - f.write(line) + else: + for idx in [1, 2, 3]: + with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f: + for i in range(1000): + seed["label"] = not seed["label"] + line = json.dumps(seed, ensure_ascii=False) + "\n" + f.write(line) diff --git a/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl index 2e11a91c6..0f9a02ea3 100644 --- a/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl +++ b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl @@ -1 +1 @@ -{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. 
journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]} +{"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]} diff --git a/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl b/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl new file mode 100644 index 000000000..4f4fce83d --- /dev/null +++ b/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl @@ -0,0 +1 @@ +{"prompt": [{"from": "user", "content": "What are some praise words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "impressive."},"label": true} diff --git a/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl index 21c4d9dc7..759bba7a0 100644 --- a/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl +++ b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl @@ -1 +1 @@ -{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. 
Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]} +{"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]} diff --git a/applications/ColossalChat/tests/test_data_preparation.sh b/applications/ColossalChat/tests/test_data_preparation.sh index a7689cdc6..427c3952b 100755 --- a/applications/ColossalChat/tests/test_data_preparation.sh +++ b/applications/ColossalChat/tests/test_data_preparation.sh @@ -71,6 +71,8 @@ get_data_input_dirs() { echo "$PROMPT_DATASET" elif [[ $data_type == "preference" ]]; then echo "$PREFERENCE_DATASET" + elif [[ $data_type == "kto" ]]; then + echo "$KTO_DATASET" else echo "Unknown data type $data_type" exit 1 @@ -121,6 +123,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \ --data_dir $(get_data_input_dirs prompt) \ --data_type "prompt" +python $TEST_DIR/generate_dummy_datasets_for_testing.py \ + --data_dir $(get_data_input_dirs kto) \ + --data_type "kto" + echo "[Test]: testing prepare_preference_dataset.py ..." # FIXME: This is a hack to skip tests that are not working @@ -258,3 +264,50 @@ for model in ${MODELS[@]}; do exit 1 fi done + + +echo "[Test]: testing prepare_kto_dataset.py ..." + +# FIXME: This is a hack to skip tests that are not working +SKIPPED_TESTS=( +) + +# test prepare_kto_dataset +for model in ${MODELS[@]}; do + data_type="kto" + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then + echo "[Test]: Skipped $model-$data_type" + continue + fi + cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache + jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl + arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow + data_input_dirs=$(get_data_input_dirs $data_type) + tokenizer_dir=$(get_tokenizer_dirs $model) + conversation_template=$(get_conversation_template_config $model) + for i in $(seq $NUM_RETRY); do + rm -rf $cache_dir + rm -rf $jsonl_dir + rm -rf $arrow_dir + echo "[Test]: $model-$data_type, attempt $i" + python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \ + --type kto \ + --data_input_dirs $data_input_dirs \ + --conversation_template_config $conversation_template \ + --tokenizer_dir $tokenizer_dir \ + --data_cache_dir $cache_dir \ + --data_jsonl_output_dir $jsonl_dir \ + --data_arrow_output_dir $arrow_dir \ + --max_length 400 \ + --num_samples_per_datafile 100 \ + --num_spliced_dataset_bins 1 + passed=$? 
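The new KTO path works on jsonl records with three fields — a `prompt` turn list, a single assistant `completion`, and a boolean `label` — and the dummy-data generator alternates the label so both desirable and undesirable samples appear. A minimal Python sketch of producing such a file (field names taken from `kto_seed` in this patch; the output path is only illustrative):

```python
import json

# Minimal KTO record mirroring the kto_seed schema used by the dummy generator.
seed = {
    "prompt": [
        {"from": "user", "content": "What are some praise words in english?"},
        {"from": "assistant", "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."},
        {"from": "user", "content": "What's your favorite one?"},
    ],
    "completion": {"from": "assistant", "content": "Impressive."},
    "label": True,
}

# Alternate the boolean label, as generate_dummy_datasets_for_testing.py does,
# so the file contains both desirable and undesirable examples.
with open("kto_dummy.jsonl", "w", encoding="utf8") as f:  # illustrative path
    for _ in range(10):
        seed["label"] = not seed["label"]
        f.write(json.dumps(seed, ensure_ascii=False) + "\n")
```

Files in this shape are what the `prepare_dataset.py --type kto` invocation in the test loop above tokenizes.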
+ if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$data_type" + exit 1 + fi +done diff --git a/applications/ColossalChat/tests/test_lora.py b/applications/ColossalChat/tests/test_lora.py index 4ea9e1a15..778759210 100755 --- a/applications/ColossalChat/tests/test_lora.py +++ b/applications/ColossalChat/tests/test_lora.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import torch.optim as optim from coati.models import convert_to_lora_module +from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear from torch.utils.data import DataLoader, TensorDataset @@ -38,7 +39,7 @@ def test_overfit(): # Build and convert model model = SimpleNN(input_size, hidden_size, num_classes) weight_to_compare = model.fc1.weight.detach().clone() - model = convert_to_lora_module(model, lora_rank=30) + model = convert_to_lora_module(model, lora_config=LoraConfig(r=32)) # Loss and optimizer criterion = nn.CrossEntropyLoss() @@ -50,7 +51,6 @@ def test_overfit(): # Forward pass outputs = model(inputs) loss = criterion(outputs, labels) - print(loss) # Backward and optimize optimizer.zero_grad() loss.backward() @@ -65,5 +65,50 @@ def test_overfit(): assert (weight_to_compare - model.fc1.weight).sum() < 0.01 +def test_lora_linear_accuracy(): + + weight = torch.randn(10, 5) + linear = nn.Linear(5, 10) + linear.weight.data = weight + x = torch.randn(10, 5) + out_linear = linear(x) + + # lora linear Pissa + linear.weight.data = weight + lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA") + out_lora = lora_linear(x) + assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05) + + # lora linear + linear.weight.data = weight + lora_linear = LoraLinear(linear.weight, linear.bias, r=2) + out_lora = lora_linear(x) + assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05) + + +def test_lora_embedding_accuracy(): + weight = torch.randn(10, 5) + embedding = nn.Embedding(10, 5) + embedding.weight.data = weight + x = torch.randint(0, 10, (10,)) + out_embedding = embedding(x) + + # lora embedding Pissa + embedding.weight.data = weight + lora_embedding = LoraEmbedding( + embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5 + ) + out_lora = lora_embedding(x) + assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05) + + # lora embedding + embedding.weight.data = weight + lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5) + out_lora = lora_embedding(x) + assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05) + + if __name__ == "__main__": test_overfit() + test_lora_linear_accuracy() + test_lora_embedding_accuracy() diff --git a/applications/ColossalChat/tests/test_templating.sh b/applications/ColossalChat/tests/test_templating.sh index d033c07f5..6ee10e8be 100755 --- a/applications/ColossalChat/tests/test_templating.sh +++ b/applications/ColossalChat/tests/test_templating.sh @@ -94,7 +94,7 @@ done # Test DPO/PPO data Preparation for model in ${MODELS[@]}; do - echo "Testing DPO/PPO data templating for $model" + echo "Testing DPO/RM data templating for $model" SAVE_DIR=$DATA_SAVE_PATH/dpo/$model rm -rf $SAVE_DIR/cache rm -rf $SAVE_DIR/jsonl @@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do --data_arrow_output_dir $SAVE_DIR/arrow passed=$? 
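The LoRA refactor above replaces the old `lora_rank` integer with a `LoraConfig` object and exposes `LoraLinear`/`LoraEmbedding` wrappers whose output matches the base layer at initialization (the low-rank update is constructed so the effective weight is unchanged, which is exactly what the new accuracy tests assert). A small sketch along the lines of those tests, assuming the `coati` package is installed:

```python
import torch
import torch.nn as nn
from coati.models import convert_to_lora_module
from coati.models.lora import LoraConfig, LoraLinear

# Wrapping a plain Linear: before any training step the LoRA delta is zero,
# so the wrapped layer reproduces the original layer's output.
linear = nn.Linear(5, 10)
x = torch.randn(4, 5)
lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
assert torch.allclose(linear(x), lora_linear(x), atol=1e-5)

# Whole-model conversion now takes a config object instead of a bare rank.
model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 2))
model = convert_to_lora_module(model, lora_config=LoraConfig(r=8))
```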
if [ $passed -ne 0 ]; then - echo "[Test]: Failed in the DPO data templating for $model" + echo "[Test]: Failed in the DPO/RM data templating for $model" exit 1 fi python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \ --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo passed=$? if [ $passed -ne 0 ]; then - echo "[Test]: Failed in the DPO data templating test for $model" + echo "[Test]: Failed in the DPO/RM data templating test for $model" + exit 1 + fi +done + + +# Test KTO data Preparation +for model in ${MODELS[@]}; do + echo "Testing KTO data templating for $model" + SAVE_DIR=$DATA_SAVE_PATH/kto/$model + rm -rf $SAVE_DIR/cache + rm -rf $SAVE_DIR/jsonl + rm -rf $SAVE_DIR/arrow + pretrain=$(get_pretrain $model) + conversation_template_config=$(get_conversation_template_config $model) + python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \ + --tokenizer_dir $pretrain \ + --conversation_template_config $conversation_template_config \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow + passed=$? + if [ $passed -ne 0 ]; then + echo "[Test]: Failed in the KTO data templating for $model" + exit 1 + fi + python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \ + --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto + passed=$? + if [ $passed -ne 0 ]; then + echo "[Test]: Failed in the KTO data templating test for $model" exit 1 fi done diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index d1a685174..c26b25c83 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -30,9 +30,10 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models MODELS_DIR=$TEMP_DIR/models_config # Skip those tests due to CI tests timeout MODELS=('llama') -ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy -PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') +ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy +PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally +LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" export OMP_NUM_THREADS=8 @@ -112,6 +113,11 @@ for lora_rank in ${LORA_RANK[@]}; do sp='1' sp_mode='split_gather' enable_sequence_parallelism='' + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi if [[ $plugin == "3d" ]]; then tp='4' bs='8' @@ -173,9 +179,10 @@ for lora_rank in ${LORA_RANK[@]}; do --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ + --eval_dataset ${dataset[@]} \ --save_path $MODEL_SAVE_PATH \ --config_file $MODELS_DIR/config.jsonl \ - --lora_rank $lora_rank \ + $lora_config \ --plugin $plugin \ --batch_size $bs \ --max_epochs 1 \ @@ -192,8 +199,8 @@ for lora_rank in ${LORA_RANK[@]}; do --use_flash_attn passed=$? 
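`LORA_CONFIG_ENABLE` points the training scripts at `examples/training_scripts/lora_config.json`; that file is not shown in this diff, so the sketch below only illustrates the general idea with PEFT-style field names (`r`, `lora_alpha`, `target_modules`, ...) — treat every key as an assumption rather than the script's actual schema:

```python
import json

# Hypothetical LoRA settings file; the real lora_config.json shipped with the
# examples may use different keys. This is only an illustration.
lora_settings = {
    "r": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj"],
}

with open("lora_config.json", "w") as f:
    json.dump(lora_settings, f, indent=2)
```

In the test matrix above, this flag is only attached when the plugin is `zero2`, since the 3D and Gemini plugins are listed in the skip comments as not supporting LoRA.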
if [ $passed -eq 0 ]; then - rm -rf $MODEL_SAVE_PATH/* - rm -rf $MODELS_DIR/* + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* break fi done @@ -229,6 +236,11 @@ for lora_rank in ${LORA_RANK[@]}; do grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") tp='1' bs='2' + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi if [[ $plugin == "3d" ]]; then tp='4' bs='8' @@ -248,9 +260,10 @@ for lora_rank in ${LORA_RANK[@]}; do --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ + --eval_dataset ${dataset[@]} \ --save_dir $MODEL_SAVE_PATH \ --config_file $MODELS_DIR/config.jsonl \ - --lora_rank $lora_rank \ + $lora_config \ --plugin $plugin \ --batch_size $bs \ --max_epochs 1 \ @@ -262,8 +275,8 @@ for lora_rank in ${LORA_RANK[@]}; do --use_flash_attn passed=$? if [ $passed -eq 0 ]; then - rm -rf $MODEL_SAVE_PATH/* - rm -rf $MODELS_DIR/* + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* break fi done @@ -306,6 +319,11 @@ for lora_rank in ${LORA_RANK[@]}; do bs='4' ebs='8' conversation_template=$(get_conversation_template_config $model) + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi if [[ $plugin == "3d" ]]; then tp='4' bs='16' @@ -342,7 +360,7 @@ for lora_rank in ${LORA_RANK[@]}; do --ptx_batch_size 1 \ --ptx_coef 0.2 \ --save_path $MODEL_SAVE_PATH \ - --lora_rank $lora_rank \ + $lora_config \ --plugin $plugin \ --num_episodes 5 \ --num_collect_steps 1 \ @@ -361,8 +379,8 @@ for lora_rank in ${LORA_RANK[@]}; do # --use_flash_attn passed=$? if [ $passed -eq 0 ]; then - rm -rf $MODEL_SAVE_PATH/* - rm -rf $MODELS_DIR/* + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* break fi done @@ -402,6 +420,11 @@ for lora_rank in ${LORA_RANK[@]}; do tp='4' bs='8' fi + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi grad_accu='2' # gemini_auto and gemini doesn't support gradient accumulation if [[ $plugin == "gemini_auto" ]]; then @@ -423,9 +446,10 @@ for lora_rank in ${LORA_RANK[@]}; do --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ + --eval_dataset ${dataset[@]} \ --save_dir $MODEL_SAVE_PATH \ --config_file $MODELS_DIR/config.jsonl \ - --lora_rank $lora_rank \ + $lora_config \ --plugin $plugin \ --batch_size $bs \ --max_epochs 1 \ @@ -437,8 +461,176 @@ for lora_rank in ${LORA_RANK[@]}; do --use_flash_attn passed=$? if [ $passed -eq 0 ]; then - rm -rf $MODEL_SAVE_PATH/* - rm -rf $MODELS_DIR/* + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank" + exit 1 + fi + done + done +done + + + +echo "[Test]: testing ORPO ..." 
+ +SKIPPED_TESTS=( + llama-3d-20 # 3d plugin doesn't support lora + llama-gemini_auto-20 # gemini_auto plugin doesn't support lora + llama-gemini-20 # gemini doesn't support lora +) +GRAD_CKPTS=('--grad_checkpoint') +for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='2' + if [[ $plugin == "3d" ]]; then + tp='4' + bs='8' + fi + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi + grad_accu='2' + # gemini_auto and gemini doesn't support gradient accumulation + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + # gemini_auto doesn't support generation + # (need to calculate ref_model logits through forwarding in inference mode) + if [[ $plugin == "gemini_auto" ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank, attempt $i" + declare -a dataset=() + for split in $(seq -f "%05g" 0 0); do + dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split") + done + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --dataset ${dataset[@]} \ + --eval_dataset ${dataset[@]} \ + --save_dir $MODEL_SAVE_PATH \ + --config_file $MODELS_DIR/config.jsonl \ + $lora_config \ + --plugin $plugin \ + --batch_size $bs \ + --max_epochs 1 \ + --accumulation_steps $grad_accu \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --use_flash_attn + passed=$? + if [ $passed -eq 0 ]; then + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank" + exit 1 + fi + done + done +done + + + +echo "[Test]: testing KTO ..." 
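The KTO run below passes `--auto_weight` and `--desirable_weight 1.2`; conceptually, these weights rescale how much desirable versus undesirable samples contribute to the batch loss. The KTO trainer itself is not part of this diff, so the following is only a schematic sketch of that weighting, not ColossalChat's implementation:

```python
import torch

def weighted_kto_batch_loss(per_sample_loss: torch.Tensor,
                            labels: torch.Tensor,
                            desirable_weight: float = 1.2,
                            undesirable_weight: float = 1.0) -> torch.Tensor:
    """Schematic: scale desirable (label=True) and undesirable samples differently."""
    weights = torch.where(labels,
                          torch.full_like(per_sample_loss, desirable_weight),
                          torch.full_like(per_sample_loss, undesirable_weight))
    return (weights * per_sample_loss).mean()

# Toy call, purely to show the expected shapes.
loss = weighted_kto_batch_loss(torch.rand(8), torch.rand(8) > 0.5)
```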
+ +SKIPPED_TESTS=( + llama-3d-20 # 3d plugin doesn't support lora + llama-gemini_auto-20 # gemini_auto plugin doesn't support lora + llama-gemini-20 # gemini doesn't support lora +) +GRAD_CKPTS=('--grad_checkpoint') +for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='2' + if [[ $plugin == "3d" ]]; then + tp='4' + bs='8' + fi + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi + grad_accu='2' + # gemini_auto and gemini doesn't support gradient accumulation + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + # gemini_auto doesn't support generation + # (need to calculate ref_model logits through forwarding in inference mode) + if [[ $plugin == "gemini_auto" ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank, attempt $i" + declare -a dataset=() + for split in $(seq -f "%05g" 0 0); do + dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split") + done + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --dataset ${dataset[@]} \ + --eval_dataset ${dataset[@]} \ + --save_dir $MODEL_SAVE_PATH \ + --config_file $MODELS_DIR/config.jsonl \ + $lora_config \ + --plugin $plugin \ + --batch_size $bs \ + --max_epochs 1 \ + --accumulation_steps $grad_accu \ + --tp $tp \ + --lr 2e-5 \ + --auto_weight \ + --desirable_weight 1.2 \ + $grad_ckpt \ + --max_len 400 \ + --use_flash_attn + passed=$? 
+ if [ $passed -eq 0 ]; then + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* break fi done diff --git a/applications/ColossalChat/tests/verify_chat_data.py b/applications/ColossalChat/tests/verify_chat_data.py index 98ae0c1b2..eb8f9ce46 100644 --- a/applications/ColossalChat/tests/verify_chat_data.py +++ b/applications/ColossalChat/tests/verify_chat_data.py @@ -62,3 +62,11 @@ if __name__ == "__main__": assert any( [rejected_lable in s for s in to_verify_lable_rejected] ), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}" + elif args.data_type == "kto": + sample = data[0] + to_verify_data = to_verify_data[0] + for line in sample["prompt"]: + assert line["content"] in to_verify_data["input_id_decode"] + assert sample["completion"]["content"] in to_verify_data["input_id_decode"] + assert sample["completion"]["content"] in to_verify_data["completion_decode"] + assert sample["label"] == to_verify_data["label"] diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index d5f230249..c1cfe37d7 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} files = glob.glob(os.path.join(path, "*.jsonl")) diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py index 531313d7e..a29f56fd1 100644 --- a/applications/ColossalEval/colossal_eval/dataset/base.py +++ b/applications/ColossalEval/colossal_eval/dataset/base.py @@ -1,6 +1,9 @@ from abc import abstractstaticmethod from colossal_eval.utils import jdump +from torch.utils.data import Dataset + +from colossalai.logging import DistributedLogger class BaseDataset: @@ -12,13 +15,24 @@ class BaseDataset: logger: Logger for the dataset. 
""" - def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False): - self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference) + def __init__(self, path, logger, *args, **kwargs): + self.dataset = self.load(path, logger, *args, **kwargs) def save(self, save_path): """Save the converted dataset""" jdump(self.dataset, save_path) @abstractstaticmethod - def load(path, logger): + def load(path, logger: DistributedLogger, *args, **kwargs): """Load the original dataset and convert it into the inference dataset""" + + +class DistributedDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 915f4d9b0..1023d1e23 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py index 477280663..05752c248 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py index 54ea478ae..0337454fa 100644 --- a/applications/ColossalEval/colossal_eval/dataset/colossalai.py +++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py @@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} data = jload(path) data_per_category = get_data_per_category(data) diff --git a/applications/ColossalEval/colossal_eval/dataset/cvalues.py b/applications/ColossalEval/colossal_eval/dataset/cvalues.py index 30e802a02..4023a4c76 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cvalues.py +++ b/applications/ColossalEval/colossal_eval/dataset/cvalues.py @@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl") data_list = [] diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py 
b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py index cda6276bf..44ccea9cf 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]: files = os.listdir(os.path.join(path, "data", category)) diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py index 9ea5e3c7d..eb61efaa0 100644 --- a/applications/ColossalEval/colossal_eval/dataset/longbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py @@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} files = os.listdir(path) diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py index dcda68e8f..e9465c91b 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py index 031415567..ef474ec4c 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py @@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset): This dataset class will convert the original dataset into the inference dataset. 
""" - def __init__(self, path, logger, few_shot): + def __init__(self, path, logger: DistributedLogger, *args, **kwargs): self.multiturn = True - self.dataset = self.load(path, logger, few_shot) + self.dataset = self.load(path, logger, *args, **kwargs) @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": defaultdict(dict)} file_path = os.path.join(path, "question.jsonl") diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py index e77a3da34..8056c3dfd 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py @@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} data_files = [os.path.join(path, file_name) for file_name in FILES] for file_path in data_files: diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py index 3eca808bb..f5f17e64c 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py @@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} data_files = [os.path.join(path, file_name) for file_name in FILES] for file_path in data_files: diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 23c399cce..e91743525 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -1,11 +1,11 @@ import copy -import math from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 from peft import PeftModel +from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer @@ -130,7 +130,7 @@ class HuggingFaceModel(BaseModel): if shard_config is not None: self.model = AutoModel.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) - self.model, sharded_parameters = shard_former.optimize(self.model) + self.model, _ = shard_former.optimize(self.model) self.model.to(get_current_device()) if peft_path is not None: @@ -325,7 +325,7 @@ class HuggingFaceModel(BaseModel): return input_ids_list, labels_list, None - def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: + def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: """ Infer the given data. This function will call self.generate() to get model outputs and also self.model() to get logits. 
@@ -359,26 +359,23 @@ class HuggingFaceModel(BaseModel): self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} - turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1 - turn_desc = "" if turn == 0 else f"-turn{turn}" - bar = tqdm( - range(math.ceil(len(data) / self.batch_size)), - desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps", + range(len(data_loader)), + desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps", disable=not is_rank_0(), ) loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - answers = copy.deepcopy(data) - for i in range(0, len(data), self.batch_size): - batch = data[i : i + self.batch_size] + answers = [] + + for i, batch in enumerate(data_loader): batch_prompt, batch_target = get_batch_prompt( - self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length + self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length ) if is_rank_0() and debug and i == 0: self.logger.info( - f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}" + f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}" ) self.logger.info("-" * 120) self.logger.info("An example prompt and prompt with target is:") @@ -402,7 +399,7 @@ class HuggingFaceModel(BaseModel): # Otherwise this will violate the single-choice setting. if calculate_loss: - labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))] + labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))] loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist() @@ -411,29 +408,30 @@ class HuggingFaceModel(BaseModel): {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs)) ] - for j in range(len(batch_prompt)): + for j in range(len(batch)): if not pretrain: - if isinstance(answers[i + j]["output"], list): - answers[i + j]["output"].append(batch_decodes[j].strip()) + if isinstance(batch[j]["output"], list): + batch[j]["output"].append(batch_decodes[j].strip()) else: - answers[i + j]["output"] = batch_decodes[j].strip() + batch[j]["output"] = batch_decodes[j].strip() if isinstance(scores, torch.Tensor): - answers[i + j]["logits_over_choices"] = probs[j] + batch[j]["logits_over_choices"] = probs[j] if calculate_loss: - answers[i + j]["loss_over_choices"] = loss_over_choices[j] + batch[j]["loss_over_choices"] = loss_over_choices[j] if calculate_loss: - answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() + batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity. # However, loss (which is per sample loss) suffices for most cases. 
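After this refactor, `inference()` consumes a `DataLoader` that yields raw record dicts instead of slicing the list by `batch_size` itself, and the dataset name and category now travel in `inference_kwargs`. A condensed sketch of how the evaluation entry point feeds it, mirroring the `DistributedDataset`, identity `collate_fn`, and non-shuffling `DistributedSampler` introduced further below in this patch (the helper function name here is only illustrative):

```python
from torch.utils.data import DataLoader, DistributedSampler
from colossal_eval.dataset.base import DistributedDataset

def build_question_loader(records, dp_size, dp_rank, batch_size):
    # Wrap the plain list of dicts; DistributedDataset is just an indexable view.
    dist_dataset = DistributedDataset(records)
    sampler = DistributedSampler(dist_dataset, num_replicas=dp_size, rank=dp_rank, shuffle=False)
    # collate_fn=lambda x: x keeps each batch as a list of dicts,
    # which is what the refactored HuggingFaceModel.inference() expects.
    return DataLoader(dist_dataset, batch_size=batch_size, sampler=sampler,
                      num_workers=8, pin_memory=True, collate_fn=lambda x: x)
```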
- answers[i + j]["loss_sum"] = batch_losses[j] - answers[i + j]["token_num"] = batch_target_token_nums[j] + batch[j]["loss_sum"] = batch_losses[j] + batch[j]["token_num"] = batch_target_token_nums[j] if batch_bytes_nums: - answers[i + j]["byte_num"] = batch_bytes_nums[j] + batch[j]["byte_num"] = batch_bytes_nums[j] + answers.extend(batch) bar.update() @@ -600,7 +598,7 @@ class HuggingFaceCausalLM(HuggingFaceModel): if shard_config is not None: self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) - self.model, sharded_parameters = shard_former.optimize(self.model) + self.model, _ = shard_former.optimize(self.model) self.model.to(get_current_device()) if peft_path is not None: diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py index 330083aa6..c0445e84e 100644 --- a/applications/ColossalEval/colossal_eval/utils/conversation.py +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -123,15 +123,13 @@ class Conversation: } -def get_few_shot_prefix( - conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int -) -> str: +def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int) -> str: """ Get few shot prefix. Args: - conv: Conversation template. - few_shot_examples: Few shot examples to generate few shot prompt prefix. + few_shot_data: Few shot examples to generate few shot prompt prefix. + tokenizer: tokenizer used to tokenize data. Returns: Few shot prompt prefix. @@ -157,7 +155,6 @@ def get_batch_prompt( batch: List[Dict], few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], - language: Optional[str], model_max_length: Optional[int], ) -> Tuple[List[Dict], List[Dict]]: """ @@ -167,6 +164,7 @@ def get_batch_prompt( conv: Conversation template. batch: Batch data to generate prompt from. few_shot_data: Few shot data to generate few shot prompt prefix. + tokenizer: tokenizer used to tokenize data. Returns: Tuple containg batch prompt and target. 
@@ -192,7 +190,7 @@ def get_batch_prompt( else: raise Exception("When using few-shot, target answer should be a string.") - few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens) + few_shot_prefix = get_few_shot_prefix(few_shot_data, tokenizer, max_tokens) conv.append_message(conv.roles[0], few_shot_prefix + query_text) conv.append_message(conv.roles[1], None) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index a7307635d..c651970ee 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -5,6 +5,8 @@ from typing import Dict, List import torch.distributed as dist from colossal_eval import dataset, models, utils +from colossal_eval.dataset.base import DistributedDataset +from torch.utils.data import DataLoader, DistributedSampler import colossalai from colossalai.accelerator import get_accelerator @@ -13,6 +15,7 @@ from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig logger = get_dist_logger() +os.environ["TOKENIZERS_PARALLELISM"] = "false" def rm_and_merge( @@ -54,7 +57,8 @@ def rm_and_merge( ) else: rank_answers = utils.jload(directory) - answers["data"].extend(rank_answers["data"]) + deduplidate_answers = [x for x in rank_answers["data"] if x not in answers["data"]] + answers["data"].extend(deduplidate_answers) answers["inference_kwargs"] = rank_answers["inference_kwargs"] for r in range(dp_size): @@ -65,7 +69,7 @@ def rm_and_merge( os.remove(directory) except Exception as e: print(e) - + print(len(answers["data"])) all_answers[category] = answers all_answers_with_dataset_class["inference_results"] = all_answers @@ -108,7 +112,12 @@ def main(args): tp_rank = coordinates[TP_AXIS] shard_config = ( - ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1) + ShardConfig( + tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=args.tp_size > 1, + parallel_output=False, + enable_all_optimization=True, + ) if args.tp_size > 1 else None ) @@ -183,6 +192,7 @@ def main(args): model_name = model_parameter["name"] model_class = eval(f"models.{model_parameter['model_class']}") paramerters = model_parameter["parameters"] + batch_size = paramerters["batch_size"] paramerters.update({"logger": logger}) paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]}) paramerters.update({"shard_config": shard_config}) @@ -192,7 +202,6 @@ def main(args): raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") for dataset_name, split_data in inference_data.items(): - start = 0 prev_questions = None for category, category_data in split_data.items(): num_turn = category_data["inference_kwargs"].get("turns", 1) @@ -201,26 +210,33 @@ def main(args): raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") answers_to_dump = copy.deepcopy(category_data) - partition_size = len(category_data["data"]) // dp_size - redundant = len(category_data["data"]) % dp_size - - # Ensure that the amount of data for inference is as consistent as possible across different processes. 
- lengths = [partition_size for _ in range(dp_size)] - for j in range(redundant): - lengths[(j + start) % dp_size] += 1 - - start = (start + redundant) % dp_size - for turn in range(num_turn): if turn == 0: - questions = category_data["data"][ - sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank] - ] + dist_dataset = DistributedDataset(category_data["data"]) else: - questions = prev_questions + dist_dataset = DistributedDataset(prev_questions) + + sampler = DistributedSampler( + dist_dataset, + num_replicas=pg_mesh.size(DP_AXIS), + rank=pg_mesh.coordinate(DP_AXIS), + shuffle=False, + ) + questions_loader = DataLoader( + dist_dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=8, + pin_memory=True, + collate_fn=lambda x: x, + ) + category_data["inference_kwargs"]["dataset"] = dataset_name + category_data["inference_kwargs"]["category"] = category answers_per_rank = model_.inference( - questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + data_loader=questions_loader, + inference_kwargs=category_data["inference_kwargs"], + debug=debug_args[dataset_name], ) prev_questions = answers_per_rank diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c210ca91e..bf0788650 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -2,7 +2,7 @@ import ctypes import random import warnings from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from copy import deepcopy from functools import partial from types import MethodType @@ -30,11 +30,15 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.d_tensor.api import is_distributed_tensor +from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero.low_level import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle from .pp_plugin_base import PipelinePluginBase @@ -61,6 +65,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): use_ddp: bool, ddp_config: dict, custom_policy: Policy, + overlap_allgather: bool = False, ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.shard_config = shard_config @@ -69,6 +74,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.sp_group = sp_group self.use_dpp = use_ddp self.require_grad_sync = True + self.overlap_allgather = overlap_allgather shardformer = ShardFormer(shard_config) if custom_policy is not None: @@ -106,6 +112,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): module = DDP(module, process_group=dp_group, **ddp_config) super().__init__(module) + if overlap_allgather: + self.op_hook = ZeroOpHook() + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + 
p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) def sync_shared_params(self): for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): @@ -197,7 +209,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) kwargs = tree_map(self.convert_fn, kwargs) - return super().forward(*args, **kwargs) + with self._wait_all_gather(): + return super().forward(*args, **kwargs) def unwrap(self): module = super().unwrap() @@ -205,6 +218,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): module = module.module return module + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) + + def _wait_all_gather(self): + return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: @@ -235,7 +255,7 @@ def get_param_info(optim: Optimizer): return param_info -def init_pipeline_optimizer(optim: Optimizer, model: Module): +def reinitialize_optimizer(optim: Optimizer, model: Module): model_params = set(model.parameters()) new_param_groups = [] for group in optim.param_groups: @@ -257,7 +277,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): ): self.param_info = param_info if use_pipeline: - init_pipeline_optimizer(optim, model) + reinitialize_optimizer(optim, model) self.model = model self.stage_manager = model.stage_manager self.shared_params = model.shared_params @@ -478,7 +498,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 if use_pipeline: - init_pipeline_optimizer(optim, model) + reinitialize_optimizer(optim, model) super().__init__( optim, precision=precision, @@ -632,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): model: HybridParallelModule, use_pipeline: bool, param_info: OrderedDict, + pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -650,6 +671,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): tp_process_group: Optional[ProcessGroup] = None, # if using tp pp_process_group: Optional[ProcessGroup] = None, # if using pp forced_dtype: Optional[torch.dtype] = None, + overlap_allgather: bool = False, ): self.model = model self.param_info = param_info @@ -658,11 +680,12 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): self.tp_pg = tp_process_group self.pp_pg = pp_process_group if use_pipeline: - init_pipeline_optimizer(optimizer, model) + reinitialize_optimizer(optimizer, model) super().__init__( optimizer=optimizer, initial_scale=initial_scale, min_scale=min_scale, + pg_to_param_list=pg_to_param_list, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, @@ -677,6 +700,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): cpu_offload=cpu_offload, dp_process_group=dp_process_group, forced_dtype=forced_dtype, + overlap_allgather=overlap_allgather, ) def sync_dp_grads(self): @@ -993,9 +1017,11 @@ class HybridParallelPlugin(PipelinePluginBase): make_vocab_size_divisible_by: int = 64, dp_outside: bool = True, overlap_p2p: bool = True, + overlap_allgather: bool = False, fp8_communication: bool = False, 
) -> None: super().__init__() + assert ( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" @@ -1038,17 +1064,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism if dp_outside: - ( - self.dp_axis, - self.pp_axis, - self.tp_axis, - self.sp_axis, - ) = ( - 0, - 1, - 2, - 3, - ) + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) else: self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 @@ -1148,6 +1164,7 @@ class HybridParallelPlugin(PipelinePluginBase): cpu_offload=cpu_offload, partition_grad=(self.zero_stage == 2), forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, ) self.max_norm = max_norm @@ -1176,7 +1193,7 @@ class HybridParallelPlugin(PipelinePluginBase): return True def support_lora(self) -> bool: - return False + return True def control_checkpoint_io(self) -> bool: return True @@ -1210,6 +1227,7 @@ class HybridParallelPlugin(PipelinePluginBase): and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all" ) + # sync gradients across DP * SP ranks if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) else: @@ -1224,6 +1242,7 @@ class HybridParallelPlugin(PipelinePluginBase): use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, + overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if zero_stage == 0: @@ -1306,7 +1325,7 @@ class HybridParallelPlugin(PipelinePluginBase): # so we disable it, performing manual reduction instead. ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() - with ctx: + with ctx, model._wait_all_gather(): outputs = self.schedule.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) @@ -1362,15 +1381,15 @@ class HybridParallelPlugin(PipelinePluginBase): kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in `DataLoader `_. - Returns: + Returns:` :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
""" _kwargs = kwargs.copy() distributed_sampler_cls = distributed_sampler_cls or DistributedSampler sampler = distributed_sampler_cls( dataset, - num_replicas=self.pg_mesh.size(self.dp_axis), - rank=self.pg_mesh.coordinate(self.dp_axis), + num_replicas=self.dp_group.size(), + rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()), shuffle=shuffle, ) @@ -1402,6 +1421,24 @@ class HybridParallelPlugin(PipelinePluginBase): return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() def enable_lora( - self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> Module: - raise NotImplementedError + from peft import PeftModel, get_peft_model + + assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." + assert self.pp_size == 1 and self.tp_size == 1 + self.lora_enabled = True + warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index bfb8930bb..64f264f7e 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -2,6 +2,7 @@ import enum import logging import os import warnings +from contextlib import nullcontext from functools import partial from pathlib import Path from types import MethodType @@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO @@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str) -> None: + def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): self.convert_fn = None if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) + self.overlap_allgather = overlap_allgather + if overlap_allgather: + self.op_hook = ZeroOpHook() + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) def forward(self, *args, **kwargs): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) kwargs = tree_map(self.convert_fn, kwargs) - return 
super().forward(*args, **kwargs) + ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + with ctx: + return super().forward(*args, **kwargs) + + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): @@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() super().load_unsharded_model(model, checkpoint, strict) model.update_master_params() @@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): load_sub_module: bool = True, ): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + return super().save_sharded_model( + model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): from peft import PeftModel assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() peft_model = model.unwrap() assert isinstance( peft_model, PeftModel @@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase): reduce_bucket_size_in_m: int = 12, communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, + overlap_allgather: bool = False, cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, @@ -316,6 +357,7 @@ class LowLevelZeroPlugin(DPPluginBase): partition_grad=(stage == 2), cpu_offload=cpu_offload, master_weights=master_weights, + overlap_allgather=overlap_allgather, fp8_communication=fp8_communication, ) self.lora_enabled = False @@ -406,7 +448,7 @@ class LowLevelZeroPlugin(DPPluginBase): group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: warnings.warn( - "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." 
) elif ( check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED @@ -433,7 +475,9 @@ class LowLevelZeroPlugin(DPPluginBase): self.add_lora_params_to_optimizer(model, optimizer) if not isinstance(model, ModelWrapper): - model = LowLevelZeroModel(model, self.precision) + model = LowLevelZeroModel( + model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + ) # TODO: Support Galore + ZeRO zero_stage = self.stage diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 2cfdd000a..b3415af0e 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,9 +1,8 @@ -import random import warnings +from collections import defaultdict from types import MethodType -from typing import Callable, Optional, OrderedDict, Tuple +from typing import Callable, List, Optional, OrderedDict, Tuple -import numpy as np import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -11,34 +10,42 @@ from torch.nn import Module from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from colossalai.booster.plugin.hybrid_parallel_plugin import ( + PRECISION_TORCH_TYPE, + SUPPORT_SP_MODE, HybridParallelAMPOptimizer, HybridParallelModule, HybridParallelNaiveOptimizer, HybridParallelPlugin, + HybridParallelZeroOptimizer, get_param_info, - init_pipeline_optimizer, + reinitialize_optimizer, ) from colossalai.checkpoint_io import MoECheckpointIO -from colossalai.cluster import ProcessGroupMesh +from colossalai.cluster.process_group_mesh import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.logging import get_dist_logger -from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.interface.optimizer import DistributedOptim +from colossalai.nn.optimizer import cast_to_distributed +from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule +from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig +from colossalai.shardformer.shard.shard_config import ShardConfig from colossalai.tensor.moe_tensor.api import is_moe_tensor -from colossalai.zero.low_level import LowLevelZeroOptimizer -class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer): +class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer): def __init__( self, optimizer: Optimizer, model: Module, use_pipeline: bool, + dp_process_group: Optional[ProcessGroup], # the dp pg for comm + tp_process_group: Optional[ProcessGroup], # if using tp + pp_process_group: Optional[ProcessGroup], # if using pp + moe_dp_group: ProcessGroup, # moe dp pg for comm param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, @@ -51,37 +58,25 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer): verbose: bool = False, reduce_bucket_size: int = 1024 * 1024, # communication communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, + overlap_communication: bool = False, partition_grad: 
bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - pp_process_group: Optional[ProcessGroup] = None, forced_dtype: Optional[torch.dtype] = None, - moe_extra_dp_process_group: Optional[ProcessGroup] = None, + overlap_allgather: bool = False, ): - self.param_info = param_info - self.stage_manager = model.stage_manager - self.shared_params = model.shared_params - self.dp_pg = dp_process_group - self.tp_pg = tp_process_group - self.pp_pg = pp_process_group - if use_pipeline: - init_pipeline_optimizer(optimizer, model) - pg_param_list = { - dp_process_group: [], - moe_extra_dp_process_group: [], + dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), + moe_dp_group: list(filter(is_moe_tensor, model.parameters())), } - for param in model.parameters(): - if is_moe_tensor(param): - pg_param_list[moe_extra_dp_process_group].append(param) - else: - pg_param_list[dp_process_group].append(param) + + if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0: + raise ValueError("No parameters found in dp_process_group or moe_dp_group") super().__init__( + model=model, optimizer=optimizer, - pg_to_param_list=pg_param_list, + use_pipeline=use_pipeline, + param_info=param_info, initial_scale=initial_scale, min_scale=min_scale, growth_factor=growth_factor, @@ -96,30 +91,37 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer): overlap_communication=overlap_communication, partition_grad=partition_grad, cpu_offload=cpu_offload, + tp_process_group=tp_process_group, + pp_process_group=pp_process_group, forced_dtype=forced_dtype, + pg_to_param_list=pg_param_list, + overlap_allgather=overlap_allgather, ) class MoeHybridParallelPlugin(HybridParallelPlugin): """ - Plugin for Moe Hybrid Parallel Training. + Plugin for MoE Hybrid Parallel Training, which is similar to HybridParallelPlugin Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import HybridParallelPlugin + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import MoeHybridParallelPlugin - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + model, train_dataset, optimizer, criterion = ... + plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2, ep_size=2) - >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + ``` Args: - pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. 
Pipeline parallelism will not be used when pp_size is set to 1. + ep_size (int): The size of expert parallelism + sp_size (int): The size of sequence parallelism. precision (str, optional): Specifies the precision of parameters during training. Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. Defaults to 'fp16'. @@ -132,7 +134,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. microbatch_size (int, optional): Microbatch size when using pipeline parallelism. Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. @@ -155,15 +159,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. - use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params. + custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. + pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. + num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. + gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. + enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. + make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. 
+ overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism """ def __init__( self, + tp_size: int, pp_size: int, ep_size: int, - tp_size: int = 1, - sp_size: int = 1, + sp_size: int = None, precision: str = "fp16", zero_stage: int = 0, enable_all_optimization: bool = False, @@ -171,7 +181,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): enable_flash_attention: bool = False, enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, enable_sequence_overlap: bool = False, + parallel_output: bool = True, num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, initial_scale: float = 2**16, @@ -191,27 +203,61 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): zero_bucket_size_in_m: int = 12, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - use_ep_inside: bool = True, + overlap_communication: bool = False, custom_policy: Policy = None, - checkpoint_io: Optional[MoECheckpointIO] = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, + enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, + moe_dp_outside: bool = True, + overlap_p2p: bool = True, + overlap_allgather: bool = False, ) -> None: - world_size = dist.get_world_size() - assert tp_size == 1, "Tensor parallel is not supported in MoE yet" - assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet" + if overlap_communication or zero_stage == 2: + overlap_communication = False + zero_stage = 1 + warnings.warn( + f"overlap_communication and zero_stage are set to False and 1 because " + f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. " + ) assert ( - world_size % (tp_size * pp_size) == 0 - ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}" - assert ( - world_size % (tp_size * pp_size * ep_size) == 0 - ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}" + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + if enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + warnings.warn( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." 
+ ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all"]: + self.sp_size = 1 if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"You should not set sp_size when sequence parallelism is not enabled." + self.sp_size = 1 - self.dp_size = world_size // (tp_size * pp_size) + assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}" + self.moe_dp_size = self.dp_size // ep_size + self.ep_size = ep_size self.tp_size = tp_size self.pp_size = pp_size - self.ep_size = ep_size - self.sp_size = sp_size self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -220,61 +266,69 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.checkpoint_io = checkpoint_io - - logger = get_dist_logger() - - # NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param - # See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient - # we change pg mesh to (pp, dp, tp) for better moe performance - assert ( - self.ep_size <= self.dp_size - ), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})." - - self.moe_dp_size = self.dp_size // self.ep_size - self.use_ep_inside = use_ep_inside - if self.use_ep_inside: - logger.info(f"MoE Parallel use ep inside dp.", ranks=[0]) - self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3 - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size) + if moe_dp_outside: + self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4 + self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size) else: - logger.info(f"MoE Parallel use ep outside dp.", ranks=[0]) - warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.") - self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3 - self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size) + self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4 + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) - self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) - self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis) - logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0]) - logger.info( - f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0] - ) - - self.tp_group = self.pg_mesh.get_group_along_axis( - self.tp_axis - ) # TODO: support custom tp size for mixtral lm head - self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis)) - self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) - # TODO: Currently moe only support partially sequence parallel - self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) - - self.custom_policy = custom_policy self.stage_manager = None self.schedule = None - + self.custom_policy = 
custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" - assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis) - self.schedule = OneForwardOneBackwardSchedule( - self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager( + self.pg_mesh, + pipeline_axis=self.pp_axis, + enable_interleave=pp_style == "interleaved", + num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, ) + if pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, + ) + elif pp_style == "1f1b": + self.schedule = OneForwardOneBackwardSchedule( + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + ) + else: + raise NotImplementedError() + + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis]) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) + self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis) + self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, + ep_group=self.ep_group, + moe_dp_group=self.moe_dp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, @@ -282,8 +336,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): enable_flash_attention=self.enable_flash_attention, enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, - ep_group=self.ep_group, + parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, + gradient_checkpoint_config=gradient_checkpoint_config, ) self.amp_config = dict( initial_scale=initial_scale, @@ -310,77 +367,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_communication=overlap_communication, cpu_offload=cpu_offload, partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, ) self.max_norm = max_norm - def 
prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs - ): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler( - dataset, - num_replicas=self.dp_size, - rank=dist.get_rank(self.global_dp_group), - shuffle=shuffle, - ) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) - def get_checkpoint_io(self) -> MoECheckpointIO: - if self.checkpoint_io is None: - self.checkpoint_io = MoECheckpointIO( - self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage - ) - else: - self.checkpoint_io = self.checkpoint_io( - self.global_dp_group, - self.pp_group, - self.tp_group, - ep_group=self.ep_group, - moe_dp_group=self.moe_dp_group, - zero_stage=self.zero_stage, - ) - if hasattr(self.checkpoint_io, "moe_info"): - self.checkpoint_io.moe_info = self.moe_info - return self.checkpoint_io + return MoECheckpointIO( + self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage + ) def configure( self, @@ -391,13 +387,40 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: param_info = get_param_info(optimizer) + + # TODO: Support Galore + ZeRO + # Replace with distributed implementation if exists + optimizer = cast_to_distributed(optimizer) + if not isinstance(model, ModelWrapper): - use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 + and self.pp_size == 1 + and self.enable_sequence_parallelism + and self.sequence_parallelism_mode == "all_to_all" + ) + if use_ddp: + warnings.warn( + f"Will have to check all params are used in pytorch DDP since not all experts are always activated" + ) + self.ddp_config["find_unused_parameters"] = True + + if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group): + raise 
ValueError( + f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0" + ) + + # sync gradients across DP * SP ranks + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) + else: + dp_group = self.dp_group + model = HybridParallelModule( module=model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.global_dp_group, + dp_group=dp_group, tp_group=self.tp_group, sp_group=self.sp_group, use_ddp=use_ddp, @@ -405,7 +428,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): custom_policy=self.custom_policy, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.ep_size > 1: + # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups + # but the optimizer is not aware of ep, so we need to update the optimizer + reinitialize_optimizer(optimizer, model) + if self.zero_stage == 0: + is_zero = False if self.precision in ["fp16", "bf16"]: optimizer = HybridParallelAMPOptimizer( optimizer, @@ -418,20 +447,30 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) else: optimizer = HybridParallelNaiveOptimizer( - optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, ) else: - assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + if self.dp_size <= 1: + warnings.warn( + "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " + "If you do not intend to use cpu_offload, please consider set zero_stage=0." + ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
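+                    # Descriptive note (added comment, not new behavior): MoeHybridParallelZeroOptimizer, defined
+                    # above, keeps two ZeRO buckets keyed by process group: non-MoE parameters are reduced over
+                    # dp_process_group, while expert parameters selected by is_moe_tensor() are reduced over
+                    # moe_dp_group (see pg_param_list in its __init__).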
optimizer = MoeHybridParallelZeroOptimizer( optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=self.global_dp_group, + dp_process_group=dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, - moe_extra_dp_process_group=self.moe_dp_group, + moe_dp_group=self.moe_dp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, @@ -440,4 +479,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): # inject update_master_params model.update_master_params = MethodType(optimizer.update_master_params, model) + # Setup optimizers that require global states + optim = optimizer.optim + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} + padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) + optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) + return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 61c9d1438..0310df548 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): """ assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() model = model.unwrap() if os.path.isfile(checkpoint): @@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): This argument should be manually set to False since params on same device might be stored in different files. """ assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model._force_wait_all_gather() model_before_wrapping = model # backup for model before wrapping model = model.unwrap() @@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() model = model.unwrap() if self.dp_rank != 0: @@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model._force_wait_all_gather() strict = False model_before_wrapping = model model = model.unwrap() @@ -943,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state_[k] = v.detach().clone().to(device) return state_ + + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." 
+ return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index a0b625008..9181956b7 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -151,13 +151,10 @@ class MoECheckpointIO(HybridParallelCheckpointIO): # ep_rank 0 saves all the parameters and buffers. # other ep_ranks save only experts - ep_param_pattern = "experts." if self.ep_rank != 0 else None # Then collect the sharded parameters & buffers along tp_group. # Only devices with tp_rank == 0 are responsible for model saving. - state_dict_shard = MoECheckpointIO._model_sharder( - model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern - ) + state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) control_saving = self.tp_rank == 0 diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py index 98191747e..14a8eabb4 100644 --- a/colossalai/cluster/dist_coordinator.py +++ b/colossalai/cluster/dist_coordinator.py @@ -44,7 +44,7 @@ class DistCoordinator(metaclass=SingletonMeta): self._rank = dist.get_rank() self._world_size = dist.get_world_size() # this is often passed by launchers such as torchrun - self._local_rank = os.environ.get("LOCAL_RANK", -1) + self._local_rank = int(os.environ.get("LOCAL_RANK", -1)) @property def rank(self) -> int: diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 1319a4529..dc96708f0 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch.distributed as dist from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import GroupMember def prod(nums: List[int]) -> int: @@ -47,7 +48,7 @@ class ProcessGroupMesh: self._shape = size self._rank = dist.get_rank() self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) - self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} + self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember.NON_GROUP_MEMBER]] = {} self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} def destroy_mesh_process_groups(self): @@ -136,7 +137,7 @@ class ProcessGroupMesh: assert mode in ["raise", "wrap", "clip"] return int(np.ravel_multi_index(coord, shape, mode)) - def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: + def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: """Get the process group with the given ranks. It the process group doesn't exist, it will be created. Args: @@ -147,10 +148,11 @@ class ProcessGroupMesh: ProcessGroup: The process group with the given ranks. 
""" ranks_in_group = sorted(ranks_in_group) - if tuple(ranks_in_group) not in self._group_to_ranks: + if tuple(ranks_in_group) not in self._ranks_to_group: group = dist.new_group(ranks_in_group, backend=backend) self._ranks_to_group[tuple(ranks_in_group)] = group - self._group_to_ranks[group] = tuple(ranks_in_group) + if group is not GroupMember.NON_GROUP_MEMBER: + self._group_to_ranks[group] = tuple(ranks_in_group) return self._ranks_to_group[tuple(ranks_in_group)] def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: @@ -238,7 +240,7 @@ class ProcessGroupMesh: for base_coord in itertools.product(*[range(s) for s in reduced_shape]): coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) - group = self.get_group(ranks_in_group, backend=backend) + group = self._get_group(ranks_in_group, backend=backend) if self._rank in ranks_in_group: target_group = group return target_group diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 0a9b5293d..76813a4a3 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -18,7 +18,7 @@ ## 📌 Introduction -ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference) +ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)

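Below is a minimal sketch of how the diffusion inference path added in this patch can be driven end to end. The class and keyword names follow the signatures introduced in this diff (`InferenceEngine` dispatching to the new `DiffusionEngine`, the `patched_parallelism_size` field of `InferenceConfig`, and generation kwargs packed into `DiffusionGenerationConfig.from_kwargs`); the checkpoint id and generation parameters are purely illustrative, and the sketch assumes the distributed environment has already been launched and that the `InferenceEngine` wrapper forwards `generate()` to the underlying engine.

```python
import torch
from diffusers import DiffusionPipeline

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Assumes torch.distributed / ColossalAI has already been initialized (e.g. under torchrun),
# since the engine builds a ProcessGroupMesh internally.
# The checkpoint id is illustrative; any diffusers pipeline covered by the added policies should work.
pipe = DiffusionPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)

# patched_parallelism_size > 1 enables the Distrifusion-style patched parallelism;
# InferenceConfig.__post_init__ then mirrors it into tp_size. Other fields keep their defaults here.
config = InferenceConfig(patched_parallelism_size=1)

engine = InferenceEngine(pipe, inference_config=config, verbose=True)

# Extra keyword arguments are collected into a DiffusionGenerationConfig via from_kwargs(),
# so height / width / num_inference_steps are forwarded to the underlying pipeline call.
images = engine.generate(prompts="an astronaut riding a horse on the moon",
                         height=1024, width=1024, num_inference_steps=20)
```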
@@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below. journal={arXiv}, year={2023} } + +# Distrifusion +@InProceedings{Li_2024_CVPR, + author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song}, + title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month={June}, + year={2024}, + pages={7183-7193} +} ``` diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index e114e8a61..072ddbcfd 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified import logging from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers.generation import GenerationConfig @@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM): enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation. start_token_size(int): The size of the start tokens, when using StreamingLLM. generated_token_size(int): The size of the generated tokens, When using StreamingLLM. + patched_parallelism_size(int): Patched Parallelism Size, When using Distrifusion """ # NOTE: arrange configs according to their importance and frequency of usage @@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM): start_token_size: int = 4 generated_token_size: int = 512 + # Acceleration for Diffusion Model(PipeFusion or Distrifusion) + patched_parallelism_size: int = 1 # for distrifusion + # pipeFusion_m_size: int = 1 # for pipefusion + # pipeFusion_n_size: int = 1 # for pipefusion + def __post_init__(self): self.max_context_len_to_capture = self.max_input_len + self.max_output_len self._verify_config() @@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM): # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit. 
self.start_token_size = self.block_size + # check Distrifusion + # TODO(@lry89757) need more detailed check + if self.patched_parallelism_size > 1: + # self.use_patched_parallelism = True + self.tp_size = ( + self.patched_parallelism_size + ) # this is not a real tp, because some annoying check, so we have to set this to patched_parallelism_size + # check prompt template if self.prompt_template is None: return @@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM): use_cuda_kernel=self.use_cuda_kernel, use_spec_dec=self.use_spec_dec, use_flash_attn=use_flash_attn, + patched_parallelism_size=self.patched_parallelism_size, ) return model_inference_config @@ -396,3 +411,50 @@ class ModelShardInferenceConfig: use_cuda_kernel: bool = False use_spec_dec: bool = False use_flash_attn: bool = False + patched_parallelism_size: int = 1 # for diffusion model, Distrifusion Technique + + +@dataclass +class DiffusionGenerationConfig: + """ + Param for diffusion model forward + """ + + prompt_2: Optional[Union[str, List[str]]] = None + prompt_3: Optional[Union[str, List[str]]] = None + height: Optional[int] = None + width: Optional[int] = None + num_inference_steps: int = None + timesteps: List[int] = None + guidance_scale: float = None + negative_prompt: Optional[Union[str, List[str]]] = ( + None # NOTE(@lry89757) in pixart default to "", in sd3 default to None + ) + negative_prompt_2: Optional[Union[str, List[str]]] = None + negative_prompt_3: Optional[Union[str, List[str]]] = None + num_images_per_prompt: Optional[int] = None + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None + latents: Optional[torch.FloatTensor] = None + prompt_embeds: Optional[torch.FloatTensor] = None + negative_prompt_embeds: Optional[torch.FloatTensor] = None + pooled_prompt_embeds: Optional[torch.FloatTensor] = None + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None + output_type: Optional[str] = None # "pil" + return_dict: bool = None + joint_attention_kwargs: Optional[Dict[str, Any]] = None + clip_skip: Optional[int] = None + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None + callback_on_step_end_tensor_inputs: List[str] = None + + def to_dict(self) -> Dict[str, Any]: + # NOTE(@lry89757) Only return the dict that not the default value None + result = {} + for field in fields(self): + value = getattr(self, field.name) + if value is not None: + result[field.name] = value + return result + + @classmethod + def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig": + return cls(**kwargs) diff --git a/colossalai/inference/core/base_engine.py b/colossalai/inference/core/base_engine.py new file mode 100644 index 000000000..392dd2990 --- /dev/null +++ b/colossalai/inference/core/base_engine.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + + +class BaseEngine(ABC): + @abstractmethod + def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None): + pass + + @abstractmethod + def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None): + """ + Init Model for Engine + """ + + @abstractmethod + def generate(self, request_ids=None, prompts=None, 
generation_config=None, **kwargs): + """ + Generate ouptput for coming requests + """ + + @abstractmethod + def add_request(self, prompts, request_ids=None, **kwargs): + """ + Add new request to Engine + """ + + @abstractmethod + def step(self): + """ + Perform one new step forward + """ + + @abstractmethod + def _verify_args(self): + """ + Verify the parameters and members of class + """ + + @torch.inference_mode() + def capture_model(self): + """ + Use cuda graph to capture model + """ + return NotImplementedError("This method should be implemented by subclasses") + + def _shardformer( + self, + model: nn.Module, + model_policy: Policy, + model_shard_infer_config: ModelShardInferenceConfig = None, + stage_manager: PipelineStageManager = None, + tp_group: ProcessGroupMesh = None, + **kwargs, + ) -> nn.Module: + """ + Initialize ShardConfig and replace the model with shardformer. + + Args: + model (nn.Module): Path or nn.Module of this model. + model_policy (Policy): The policy to shardformer model which is determined by the model type. + stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. + tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. + + Returns: + nn.Module: The model optimized by Shardformer. + """ + + shardconfig = ShardConfig( + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=(self.inference_config.tp_size > 1), + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs}, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py new file mode 100644 index 000000000..8bed508cb --- /dev/null +++ b/colossalai/inference/core/diffusion_engine.py @@ -0,0 +1,200 @@ +from itertools import count +from typing import List, Tuple, Type, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from torch import distributed as dist + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.struct import DiffusionSequence +from colossalai.inference.utils import get_model_size, get_model_type +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .base_engine import BaseEngine +from .request_handler import NaiveRequestHandler + +PP_AXIS, TP_AXIS = 0, 1 + + +class DiffusionEngine(BaseEngine): + def __init__( + self, + model_or_path: DiffusionPipeline | str, + inference_config: InferenceConfig = None, + verbose: bool = False, + model_policy: Policy | type[Policy] = None, + ) -> None: + self.inference_config = inference_config + self.dtype = inference_config.dtype + self.high_precision = inference_config.high_precision + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + 
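+        # Descriptive note (added comment): the shard-inference config built below carries the new
+        # patched_parallelism_size flag from InferenceConfig into the sharded modules; it is forwarded
+        # through ShardConfig.extra_kwargs by BaseEngine._shardformer.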
self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + + self.model_type = get_model_type(model_or_path=model_or_path) + + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) + + self.request_handler = NaiveRequestHandler() + + self.counter = count() + + self._verify_args() + + def _verify_args(self) -> None: + assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe" + + def init_model( + self, + model_or_path: Union[str, nn.Module, DiffusionPipeline], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. + """ + if isinstance(model_or_path, str): + model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype) + policy_map_key = model.__class__.__name__ + model = DiffusionPipe(model) + elif isinstance(model_or_path, DiffusionPipeline): + policy_map_key = model_or_path.__class__.__name__ + model = DiffusionPipe(model_or_path) + else: + self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!") + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + model_policy = model_policy_map.get(policy_map_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + model_shard_infer_config, + None, + tp_group=tp_group, + ) + + self.model = model.to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + free_gpu_memory, _ = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + def generate( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + generation_config: DiffusionGenerationConfig = None, + **kwargs, + ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]: + """ """ + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + prompts = [prompts] if isinstance(prompts, str) else prompts + request_ids = [request_ids] if isinstance(request_ids, int) else request_ids + + with torch.inference_mode(): + if prompts is not None: + 
self.add_request( + request_ids=request_ids, + prompts=prompts, + **gen_config_dict, + **kwargs, + ) + + output_reqs_list = [] + + # intuition: If user provide a generation config, we should replace the existing one. + if generation_config is not None: + self.generation_config = generation_config + self.generation_config_dict = gen_config_dict + + while self.request_handler.check_unfinished_reqs(): + output_reqs_list += self.step() + + return output_reqs_list + + def add_request( + self, + prompts: Union[List[str], str], + request_ids: Union[List[int], int] = None, + **kwargs, + ): + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + + if not isinstance(prompts, list): + prompts = [prompts] + + generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs) + prompts_num = len(prompts) + for i in range(prompts_num): + if request_ids: + assert isinstance( + request_ids[0], int + ), f"The request_id type must be int, but got {type(request_ids[0])}" + assert len(request_ids) == prompts_num + request_id = request_ids[i] + else: + request_id = next(self.counter) + + seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config) + + self.request_handler.add_sequence(seq) + + def step(self) -> List[PIL.Image.Image]: + """ + In each step, do the follows: + 1. Run RequestHandler.schedule() and get the batch used for inference. + 2. run forward to get List[Image] + Returns: + List[PIL.Image.Image]: Image Generated by one step. + """ + + input = self.request_handler.schedule() + ret = self.model(prompt=input.prompt, **input.generation_config.to_dict()) + return ret diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 8f8aef65e..5c9bdc321 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,57 +1,24 @@ -import time -from itertools import count -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import List, Tuple, Type, Union import numpy as np -import torch +import PIL.Image import torch.nn as nn -from torch import distributed as dist -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - GenerationConfig, - PreTrainedTokenizer, - PreTrainedTokenizerFast, -) -from transformers.models.llama.modeling_llama import LlamaForCausalLM +from diffusers import DiffusionPipeline +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from colossalai.accelerator import get_accelerator -from colossalai.cluster import ProcessGroupMesh -from colossalai.inference.batch_bucket import BatchBucket -from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig -from colossalai.inference.graph_runner import CUDAGraphRunner -from colossalai.inference.modeling.policy import model_policy_map -from colossalai.inference.sampler import search_tokens -from colossalai.inference.spec import Drafter, GlideInput -from colossalai.inference.struct import Sequence -from colossalai.inference.utils import get_model_size, has_index_file -from colossalai.interface import ModelWrapper -from colossalai.lazy import LazyInitContext -from colossalai.logging import get_dist_logger -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.config import InferenceConfig +from colossalai.inference.utils import ModelType, get_model_type from colossalai.shardformer.policies.base_policy import 
Policy -from .request_handler import RequestHandler - __all__ = ["InferenceEngine"] -PP_AXIS, TP_AXIS = 0, 1 - -_supported_models = { - "LlamaForCausalLM": LlamaForCausalLM, - "BaichuanForCausalLM": AutoModelForCausalLM, -} - -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] - class InferenceEngine: """ InferenceEngine which manages the inference process.. Args: - model_or_path (nn.Module or str): Path or nn.Module of this model. + model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model. tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. verbose (bool): Determine whether or not to log the generation process. @@ -60,567 +27,68 @@ class InferenceEngine: def __init__( self, - model_or_path: Union[nn.Module, str], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - inference_config: InferenceConfig, + model_or_path: Union[nn.Module, str, DiffusionPipeline], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, + inference_config: InferenceConfig = None, verbose: bool = False, model_policy: Union[Policy, Type[Policy]] = None, ) -> None: - self.inference_config = inference_config - self.dtype = inference_config.dtype - self.high_precision = inference_config.high_precision + self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__ + self.model_type = get_model_type(model_or_path=model_or_path) + self.engine = None + if self.model_type == ModelType.LLM: + from .llm_engine import LLMEngine - self.verbose = verbose - self.logger = get_dist_logger(__name__) - self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + self.engine = LLMEngine( + model_or_path=model_or_path, + tokenizer=tokenizer, + inference_config=inference_config, + verbose=verbose, + model_policy=model_policy, + ) + elif self.model_type == ModelType.DIFFUSION_MODEL: + from .diffusion_engine import DiffusionEngine - self.init_model(model_or_path, model_policy, self.model_shard_infer_config) - - self.generation_config = inference_config.to_generation_config(self.model_config) - self.generation_config_dict = self.generation_config.to_dict() - - self.tokenizer = tokenizer - self.tokenizer.pad_token = self.tokenizer.eos_token - - self.request_handler = RequestHandler(self.inference_config, self.model_config) - self.k_cache, self.v_cache = self.request_handler.get_kvcache() - # DISCUSS maybe move this into batch info? - - self.counter = count() - - self.use_cuda_graph = self.inference_config.use_cuda_graph - if self.use_cuda_graph: - self.graph_runners: Dict[int, CUDAGraphRunner] = {} - self.graph_memory_pool = None # Set during graph capture. 
- if verbose: - self.logger.info("Colossal AI CUDA Graph Capture on") - - self.capture_model(self.k_cache, self.v_cache) - - # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` - self.use_spec_dec = self.inference_config.use_spec_dec - - self.drafter_model = None - self.drafter = None - self.use_glide = False - self.n_spec_tokens = self.inference_config.max_n_spec_tokens + self.engine = DiffusionEngine( + model_or_path=model_or_path, + inference_config=inference_config, + verbose=verbose, + model_policy=model_policy, + ) + elif self.model_type == ModelType.UNKNOWN: + self.logger.error(f"Model Type either Difffusion or LLM!") + self._initialized = True self._verify_args() - def init_model( - self, - model_or_path: Union[nn.Module, str], - model_policy: Union[Policy, Type[Policy]] = None, - model_shard_infer_config: ModelShardInferenceConfig = None, - ): - """ - Shard model or/and Load weight - - Args: - model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. - model_policy (Policy): the policy to replace the model. - model_inference_config: the configuration for modeling initialization when inference. - model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. - """ - pretrained_path = None - if isinstance(model_or_path, str): - import colossalai.interface.pretrained as pretrained_utils - - try: - hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) - arch = getattr(hf_config, "architectures")[0] - if arch in _supported_models.keys(): - if arch is "BaichuanForCausalLM": - self.logger.warning( - "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers" - ) - ctx = LazyInitContext(default_device="cuda") - with ctx: - model = _supported_models[arch].from_pretrained( - model_or_path, trust_remote_code=True, torch_dtype=self.dtype - ) - pretrained_path = pretrained_utils.get_pretrained_path(model) - else: - # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate - raise ValueError(f"Model {arch} is not supported.") - - except Exception as e: - self.logger.error( - f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" - ) - else: - model = model_or_path - - self.model_config = model.config - - torch.cuda.empty_cache() - init_gpu_memory = torch.cuda.mem_get_info()[0] - - self.device = get_accelerator().get_current_device() - if self.verbose: - self.logger.info(f"the device is {self.device}") - - model = model.to(self.dtype).eval() - - if self.verbose: - self.logger.info( - f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" - ) - - if model_policy is None: - prefix = "nopadding" if not self.inference_config.pad_input else "padding" - model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" - model_policy = model_policy_map.get(model_policy_key) - - if not isinstance(model_policy, Policy): - try: - model_policy = model_policy() - except Exception as e: - raise ValueError(f"Unable to instantiate model policy: {e}") - - assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" - pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) - tp_group = pg_mesh.get_group_along_axis(TP_AXIS) - - self.model = 
self._shardformer( - model, - model_policy, - model_shard_infer_config, - None, - tp_group=tp_group, - ) - - self.model = ModelWrapper(model).to(self.device) - - if self.verbose: - self.logger.info( - f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" - ) - - if pretrained_path: - from colossalai.inference.core.plugin import InferCheckpoint_io - - cpt_io = InferCheckpoint_io() - if_has_index_file, model_index_file = has_index_file(pretrained_path) - assert if_has_index_file, "the model path is invalid" - cpt_io.load_model(self.model, model_index_file) - - free_gpu_memory, _ = torch.cuda.mem_get_info() - peak_memory = init_gpu_memory - free_gpu_memory - if self.verbose: - self.logger.info( - f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" - ) - - @torch.inference_mode() - def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): - assert self.use_cuda_graph, "please turn on the cuda graph" - - if self.verbose: - self.logger.info("Colossal AI CUDA Graph Capture begin") - - t_capture_begin = time.perf_counter() - - block_size = self.inference_config.block_size - head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads - - # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - max_context_len_to_capture = self.inference_config.max_context_len_to_capture - max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size - input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() - # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) - self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) - self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) - self.graph_block_tables[0, :] = np.arange( - 0, max_num_blocks - ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len - block_tables = torch.from_numpy(self.graph_block_tables).cuda() - output_tensor = torch.zeros( - (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device - ) - fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor - - max_num_seqs = self.inference_config.max_batch_size - batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] - sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() - # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len - sequence_lengths[0] = torch.tensor( - self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 - ).cuda() - - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. 
- for batch_size in reversed(batch_size_capture_list): - if self.verbose: - self.logger.info(f"batch size {batch_size} graph capturing") - - input_meta_data = InputMetaData( - block_tables=block_tables[:batch_size], - sequence_lengths=sequence_lengths[:batch_size], - fd_inter_tensor=fd_inter_tensor, - batch_size=batch_size, - is_prompts=False, - use_cuda_graph=True, - high_precision=False, - kv_seq_len=sequence_lengths[:batch_size].max().item(), - head_dim=head_dim, - dtype=self.dtype, - ) - - graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( - input_tokens_ids[:batch_size], - output_tensor[:batch_size], - input_meta_data, - k_caches=k_cache, - v_caches=v_cache, - memory_pool=self.graph_memory_pool, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner - - t_capture_end = time.perf_counter() - - if self.verbose: - self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") - def _verify_args(self) -> None: """Verify the input args""" - if not isinstance(self.inference_config, InferenceConfig): - raise TypeError("Invalid type of inference config provided.") - if not isinstance(self.model, nn.Module): - raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") - if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): - raise TypeError( - f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" - ) - if isinstance(self.model, ModelWrapper): - model = self.model.module - assert ( - model.__class__.__name__ in _supported_models.keys() - ), f"Model {self.model.__class__.__name__} is not supported." - - def _shardformer( - self, - model: nn.Module, - model_policy: Policy, - model_shard_infer_config: ModelShardInferenceConfig = None, - stage_manager: PipelineStageManager = None, - tp_group: ProcessGroupMesh = None, - ) -> nn.Module: - """ - Initialize ShardConfig and replace the model with shardformer. - - Args: - model (nn.Module): Path or nn.Module of this model. - model_policy (Policy): The policy to shardformer model which is determined by the model type. - stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. - tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. - - Returns: - nn.Module: The model optimized by Shardformer. - """ - - shardconfig = ShardConfig( - tensor_parallel_process_group=tp_group, - pipeline_stage_manager=stage_manager, - enable_tensor_parallelism=(self.inference_config.tp_size > 1), - enable_fused_normalization=False, - enable_all_optimization=False, - enable_flash_attention=False, - enable_jit_fused=False, - enable_sequence_parallelism=False, - extra_kwargs={"model_shard_infer_config": model_shard_infer_config}, - ) - shardformer = ShardFormer(shard_config=shardconfig) - shard_model, _ = shardformer.optimize(model, model_policy) - return shard_model - - def enable_spec_dec( - self, - drafter_model: nn.Module = None, - n_spec_tokens: int = None, - use_glide_drafter: bool = False, - ) -> None: - """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. - - Args: - drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. - If provided, the previous drafter and drafter model, if exist, will be overwritten. - n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. 
- If not provided, `max_n_spec_tokens` in InferenceConfig will be used. - use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. - If True, the drafter model will be replaced by a glide model. - - ```python - ... - engine = InferenceEngine(model, tokenizer, inference_config) - - engine.enable_spec_dec(drafter_model, n_spec_tokens=5) - engine.generate(...) # Speculative Decoding - - engine.disable_spec_dec() - engine.generate(...) # Normal generation - - engine.enable_spec_dec() - engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens - engine.clear_spec_dec() - ``` - """ - - if drafter_model is None and self.drafter is None: - raise ValueError("Drafter not initialized. Please provide a Drafter Model") - if n_spec_tokens is not None: - assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens - self.n_spec_tokens = n_spec_tokens - if drafter_model is not None: - assert isinstance(drafter_model, nn.Module) - # overwrite the drafter, if exists - self.clear_spec_dec() - self.drafter_model = drafter_model - self.drafter = Drafter( - self.drafter_model, - self.tokenizer, - device=self.device, - dtype=self.dtype, - ) - - # check if the provided drafter model is compatible with GLIDE structure - # when `use_glide_drafter` is set to True - if ( - use_glide_drafter - and hasattr(drafter_model, "model") - and hasattr(drafter_model.model, "layers") - and hasattr(drafter_model.model.layers[0], "cross_attn") - ): - self.use_glide = use_glide_drafter - elif use_glide_drafter: - self.logger.warning( - f"`use_glide_drafter` is provided as {use_glide_drafter}, " - f"but the provided drafter model is not compatible with GLIDE structure." - f"Falling back to use the default drafter model (non-GLIDE)." - ) - self.request_handler.set_spec_dec_mode(self.n_spec_tokens) - # using speculative decoding for subsequent generations - self.use_spec_dec = True - - def disable_spec_dec(self) -> None: - """Disable using speculative decoding for subsequent generations.""" - self.request_handler.unset_spec_dec_mode() - # set back to the maximum number of tokens to speculate - self.n_spec_tokens = self.inference_config.max_n_spec_tokens - self.use_glide = False - self.use_spec_dec = False - - def clear_spec_dec(self) -> None: - """Clear relatable structures of speculative decoding, if exist.""" - if self.use_spec_dec: - self.disable_spec_dec() - if self.drafter_model or self.drafter: - self.drafter_model = None - self.drafter = None - torch.cuda.empty_cache() - self.use_glide = False - self.use_spec_dec = False - - def steps_spec_dec(self) -> List[Sequence]: - """ - Run Speculative Decoding steps. This is like retrieving a single batch and launch inference - with many steps of speculating by a drafter model as well as verifying by a main model. - - Returns: - List[Sequence]: finished sequences generated by one step. - """ - batch = self.request_handler.schedule() # prefill batch - assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." - - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - - if input_meta_data.use_cuda_graph: - model_executable = self.graph_runners[input_meta_data.batch_size] - else: - model_executable = self.model - - # 1. 
Prefill small model (Drafter) - fill past kv cache for drafter model - # NOTE For glide drafter models, we won't actually apply glide during prefill stage - drafter_out = self.drafter.speculate(input_token_ids, 1, None) - next_token_ids_spec = drafter_out.next_tokens - drafter_past_key_values = drafter_out.past_key_values - - # 2. Prefill main model (Verifier) - fill past kv cache for main model - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) - # append new inputs to the batch, temporarily - batch.append_batch_tokens(next_tokens) - self.request_handler.allocate_batch_spec_dec(batch, 1) - already_allocated_kv_len = batch.seq_lengths[0].item() - input_token_ids = batch.get_1D_inputs_spec_dec(1) - - finished_sequences = self.request_handler.update() - - while True: - # HACK Retrieve the running batch - # Using RequestHandler.schedule here will re-allocate same kv cache for the batch - batch = self.request_handler.running_bb # running batch - assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." - - # 3. Decoding - Drafter model speculates `n` tokens - glide_input = None - if self.use_glide: - glide_input = GlideInput( - batch.get_block_table_tensor(), - self.k_cache[-1], # use kv cahces of the last layer - self.v_cache[-1], - batch.get_sequence_lengths(), - n_spec_tokens=self.n_spec_tokens, - ) - - drafter_out = self.drafter.speculate( - input_token_ids, - self.n_spec_tokens, - drafter_past_key_values, - glide_input=glide_input, - ) - next_token_ids_spec = drafter_out.next_tokens - drafter_past_key_values = drafter_out.past_key_values - drafter_spec_length = drafter_out.speculated_length - - for next_token_id_spec in next_token_ids_spec: - self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) - cur_length = batch.seq_lengths[0].item() - if already_allocated_kv_len < cur_length: - self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) - already_allocated_kv_len = cur_length - - # 4. Decoding - Main model verifies `n` tokens in parallel - if drafter_spec_length < batch.num_tokens_to_verify: - batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - - next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) - - # 5. 
Compare and process the results - diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) - n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() - - # revoke appended tokens for each Sequence in the current batch - batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens - - # append the last correct token generated by the main model - self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) - - # trim past key values of the drafter model - drafter_past_key_values = Drafter.trim_kv_cache( - drafter_past_key_values, drafter_spec_length - n_matches - 1 - ) - - # prepare inputs for the next round of speculation - n = 1 if n_matches < drafter_spec_length else 2 - input_token_ids = batch.get_1D_inputs_spec_dec(n) - - self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) - finished_sequences = self.request_handler.update() - if len(finished_sequences) > 0: - break - - # Reset back the number of speculated tokens of the batch, - # this is used to handle the last round of speculation, in which case the number of speculated tokens - # by the drafter is less than the number of speculated tokens set to the engine. - batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) - - return finished_sequences + assert self.engine is not None, "Please init Engine first" + assert self._initialized, "Engine must be initialized" def generate( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, - prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, - return_token_ids: bool = False, - generation_config: Optional[GenerationConfig] = None, - ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: + *args, + **kwargs, + ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]: """ Executing the inference step. Args: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. - prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. - return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. - generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. - - Returns: - Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. """ - gen_config_dict = generation_config.to_dict() if generation_config is not None else {} - prompts = [prompts] if isinstance(prompts, str) else prompts - request_ids = [request_ids] if isinstance(request_ids, int) else request_ids - - with torch.inference_mode(): - if prompts is not None or prompts_token_ids is not None: - self.add_request( - request_ids=request_ids, - prompts=prompts, - prompts_token_ids=prompts_token_ids, - **gen_config_dict, - ) - - output_seqs_list = [] - total_tokens_list = [] - - # intuition: If user provide a generation config, we should replace the existing one. - if generation_config is not None: - self.generation_config = generation_config - self.generation_config_dict = gen_config_dict - - if self.use_spec_dec: - assert self.drafter is not None, "Drafter Model is not initialized." 
- while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.steps_spec_dec() - else: - while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.step() - - output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) - - for seq in output_seqs_list: - total_tokens_list.append(seq.input_token_id + seq.output_token_id) - - output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) - - if return_token_ids: - output_tokens_list = [seq.output_token_id for seq in output_seqs_list] - return output_str, output_tokens_list - else: - return output_str - - @property - def has_prompt_template(self) -> bool: - """ """ - return self.inference_config.prompt_template is not None - - def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: - """ - This method will format the input prompt according to the prompt template given to the InferenceConfig. - """ - assert ( - self.has_prompt_template - ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." - - if isinstance(prompts, (list, tuple)): - return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] - elif isinstance(prompts, str): - return self.inference_config.prompt_template.format(input_text=prompts) - else: - raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") + assert self.engine is not None, "Please init Engine first" + return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs) def add_request( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, - prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + *args, **kwargs, ) -> None: """ @@ -630,168 +98,36 @@ class InferenceEngine: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + kwargs: for LLM, it could be max_length, max_new_tokens, etc + for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers """ + assert self.engine is not None, "Please init Engine first" + self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs) - # apply the prompt template to the input prompts + def step(self): + assert self.engine is not None, "Please init Engine first" + return self.engine.step() - if self.has_prompt_template and prompts is not None: - prompts = self.format_prompt(prompts) - - block_size = self.inference_config.block_size - - if request_ids is not None and not isinstance(request_ids, list): - request_ids = [request_ids] - - if prompts is not None and not isinstance(prompts, list): - prompts = [prompts] - - if prompts_token_ids is None: - assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." 
- prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ - "input_ids" - ] - - # list of torch Tensor - if isinstance(prompts_token_ids, list): - if isinstance(prompts_token_ids[0], torch.Tensor): - prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] - elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): - prompts_token_ids = prompts_token_ids.tolist() - else: - raise TypeError( - f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." - ) - - assert ( - len(prompts_token_ids[0]) <= self.inference_config.max_input_len - ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." - - prompts_num = len(prompts_token_ids) - - for i in range(prompts_num): - if request_ids: - assert isinstance( - request_ids[0], int - ), f"The request_id type must be int, but got {type(request_ids[0])}" - assert len(request_ids) == prompts_num - request_id = request_ids[i] + def __getattr__(self, name): + """ + The Design logic of getattr, setattr: + 1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we hope to invoke all the member of DiffusionEngine/LLMEngine like we just call the member of InferenceEngine. + 2. When we call the __init__ of InferenceEngine, we don't want to setattr using self.__dict__["xxx"] = xxx, we want to use origin ways like self.xxx = xxx + So we set the attribute `_initialized`. And after initialized, if we couldn't get the member from InferenceEngine, we will try to get the member from self.engine(DiffusionEngine/LLMEngine) + """ + if self.__dict__.get("_initialized", False): + if name in self.__dict__: + return self.__dict__[name] else: - request_id = next(self.counter) - if prompts == None: - prompt = None + return getattr(self.engine, name) + else: + return self.__dict__[name] + + def __setattr__(self, name, value): + if self.__dict__.get("_initialized", False): + if name in self.__dict__: + self.__dict__[name] = value else: - prompt = prompts[i] - - max_length = kwargs.get("max_length", None) - max_new_tokens = kwargs.get("max_new_tokens", None) - if max_length is None and max_new_tokens is None: - max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len - elif max_length is not None: - max_new_tokens = max_length - len(prompts_token_ids[i]) - - if not self.inference_config.enable_streamingllm: - assert ( - self.inference_config.max_output_len >= max_new_tokens - ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." 
- - sequence = Sequence( - request_id, - prompt, - prompts_token_ids[i], - block_size, - None, - self.tokenizer.eos_token_id, - self.tokenizer.pad_token_id, - max_output_len=max_new_tokens, - ignore_eos=self.inference_config.ignore_eos, - ) - self.request_handler.add_sequence(sequence) - - def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: - input_ids = batch.get_1D_inputs() - sequence_lengths = batch.get_sequence_lengths() - - if batch.is_prompts: - n_tokens = sequence_lengths.sum().item() + setattr(self.engine, name, value) else: - n_tokens = batch.current_batch_size - if batch.use_spec_dec: - n_tokens = batch.num_tokens_to_verify + 1 - assert n_tokens == input_ids.size(0) - n_tokens = n_tokens * batch.current_batch_size - output_tensor = torch.zeros( - (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device - ) - - batch_token_ids = None - if ( - self.generation_config.repetition_penalty != 1.0 - or self.generation_config.no_repeat_ngram_size > 0 - or self.generation_config.forced_eos_token_id is not None - ): - batch_token_ids = batch.batch_token_ids - - # only when we have the graph for specific decoding batch size can we use the cuda graph for inference - use_cuda_graph = False - if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): - use_cuda_graph = True - - input_meta_data = InputMetaData( - block_tables=batch.get_block_table_tensor(), - sequence_lengths=sequence_lengths, - fd_inter_tensor=batch.fd_inter_tensor, - batch_size=batch.current_batch_size, - is_prompts=batch.is_prompts, - use_cuda_kernel=self.inference_config.use_cuda_kernel, - use_cuda_graph=use_cuda_graph, - high_precision=self.high_precision, - kv_seq_len=sequence_lengths.max().item(), - head_dim=batch.head_dim, - dtype=batch.dtype, - use_spec_dec=batch.use_spec_dec, - num_tokens_to_verify=batch.num_tokens_to_verify, - batch_token_ids=batch_token_ids, - ) - - return input_ids, output_tensor, input_meta_data - - def step(self) -> List[str]: - """ - In each step, do the follows: - 1. Run RequestHandler.schedule() and get the batch used for inference. - 2. Get the input, inputinfo and output placeholder from the batchbucket - 3. Run model to generate the next token - 4. Update waiting list and running list in RequestHandler and get finished sequences. - 5. Decode and return finished sequences. - - Returns: - List[str]: Decoded finished sequences generated by one step. - """ - - batch = self.request_handler.schedule() - - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - - if input_meta_data.use_cuda_graph: - model_executable = self.graph_runners[input_meta_data.batch_size] - else: - model_executable = self.model - - # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. 
- logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - if self.inference_config.pad_input: - logits = logits[:, -1, :] - - if self.inference_config.enable_streamingllm: - updated_block_ids = batch.streamingllm_update_batch( - self.inference_config.start_token_size, self.inference_config.generated_token_size - ) - self.request_handler.streamingllm_free_block_tables(updated_block_ids) - - next_tokens = search_tokens( - self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids - ) - self.request_handler.append_next_tokens(next_tokens) - finished_sequences = self.request_handler.update() - - return finished_sequences + self.__dict__[name] = value diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py new file mode 100644 index 000000000..1dbc3ace8 --- /dev/null +++ b/colossalai/inference/core/llm_engine.py @@ -0,0 +1,758 @@ +import time +from itertools import count +from typing import Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import distributed as dist +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig +from colossalai.inference.graph_runner import CUDAGraphRunner +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.sampler import search_tokens +from colossalai.inference.spec import Drafter, GlideInput +from colossalai.inference.struct import Sequence +from colossalai.inference.utils import get_model_size, has_index_file +from colossalai.interface import ModelWrapper +from colossalai.lazy import LazyInitContext +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .base_engine import BaseEngine +from .request_handler import RequestHandler + +PP_AXIS, TP_AXIS = 0, 1 + +_supported_models = { + "LlamaForCausalLM": LlamaForCausalLM, + "BaichuanForCausalLM": AutoModelForCausalLM, +} + +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + + +class LLMEngine(BaseEngine): + """ + InferenceEngine which manages the inference process.. + + Args: + model_or_path (nn.Module or str): Path or nn.Module of this model. + tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. + inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. + verbose (bool): Determine whether or not to log the generation process. + model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. 
+ """ + + def __init__( + self, + model_or_path: Union[nn.Module, str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, + inference_config: InferenceConfig = None, + verbose: bool = False, + model_policy: Union[Policy, type[Policy]] = None, + ) -> None: + self.inference_config = inference_config + self.dtype = inference_config.dtype + self.high_precision = inference_config.high_precision + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) + + self.generation_config = inference_config.to_generation_config(self.model_config) + self.generation_config_dict = self.generation_config.to_dict() + + self.tokenizer = tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.request_handler = RequestHandler(self.inference_config, self.model_config) + self.k_cache, self.v_cache = self.request_handler.get_kvcache() + # DISCUSS maybe move this into batch info? + + self.counter = count() + + self.use_cuda_graph = self.inference_config.use_cuda_graph + if self.use_cuda_graph: + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + if verbose: + self.logger.info("Colossal AI CUDA Graph Capture on") + + self.capture_model(self.k_cache, self.v_cache) + + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = self.inference_config.use_spec_dec + + self.drafter_model = None + self.drafter = None + self.use_glide = False + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + + self._verify_args() + + def init_model( + self, + model_or_path: Union[nn.Module, str], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. + """ + pretrained_path = None + if isinstance(model_or_path, str): + import colossalai.interface.pretrained as pretrained_utils + + try: + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) + arch = getattr(hf_config, "architectures")[0] + if arch in _supported_models.keys(): + if arch == "BaichuanForCausalLM": + self.logger.warning( + "Attention ! We use lazy init by default, which could be faster for model loading. 
For baichuan model, the output maybe have a slight difference with transformers" + ) + ctx = LazyInitContext(default_device="cuda") + with ctx: + model = _supported_models[arch].from_pretrained( + model_or_path, trust_remote_code=True, torch_dtype=self.dtype + ) + pretrained_path = pretrained_utils.get_pretrained_path(model) + else: + # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate + raise ValueError(f"Model {arch} is not supported.") + + except Exception as e: + self.logger.error( + f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" + ) + else: + model = model_or_path + + self.model_config = model.config + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + model = model.to(self.dtype).eval() + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + prefix = "nopadding" if not self.inference_config.pad_input else "padding" + model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" + model_policy = model_policy_map.get(model_policy_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + model_shard_infer_config, + None, + tp_group=tp_group, + ) + + self.model = ModelWrapper(model).to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + if pretrained_path: + from colossalai.inference.core.plugin import InferCheckpoint_io + + cpt_io = InferCheckpoint_io() + if_has_index_file, model_index_file = has_index_file(pretrained_path) + assert if_has_index_file, "the model path is invalid" + cpt_io.load_model(self.model, model_index_file) + + free_gpu_memory, _ = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + @torch.inference_mode() + def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): + assert self.use_cuda_graph, "please turn on the cuda graph" + + if self.verbose: + self.logger.info("Colossal AI CUDA Graph Capture begin") + + t_capture_begin = time.perf_counter() + + block_size = self.inference_config.block_size + head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads + + # Prepare dummy inputs. These will be reused for all batch sizes. 
+ max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + max_context_len_to_capture = self.inference_config.max_context_len_to_capture + max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size + input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() + # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) + self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) + self.graph_block_tables[0, :] = np.arange( + 0, max_num_blocks + ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + output_tensor = torch.zeros( + (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device + ) + fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor + + max_num_seqs = self.inference_config.max_batch_size + batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] + sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() + # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + sequence_lengths[0] = torch.tensor( + self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 + ).cuda() + + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + if self.verbose: + self.logger.info(f"batch size {batch_size} graph capturing") + + input_meta_data = InputMetaData( + block_tables=block_tables[:batch_size], + sequence_lengths=sequence_lengths[:batch_size], + fd_inter_tensor=fd_inter_tensor, + batch_size=batch_size, + is_prompts=False, + use_cuda_graph=True, + high_precision=False, + kv_seq_len=sequence_lengths[:batch_size].max().item(), + head_dim=head_dim, + dtype=self.dtype, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens_ids[:batch_size], + output_tensor[:batch_size], + input_meta_data, + k_caches=k_cache, + v_caches=v_cache, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + t_capture_end = time.perf_counter() + + if self.verbose: + self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") + + def _verify_args(self) -> None: + """Verify the input args""" + if not isinstance(self.inference_config, InferenceConfig): + raise TypeError("Invalid type of inference config provided.") + if not isinstance(self.model, nn.Module): + raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") + if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): + raise TypeError( + f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" + ) + if isinstance(self.model, ModelWrapper): + model = self.model.module + assert ( + model.__class__.__name__ in _supported_models.keys() + ), f"Model {self.model.__class__.__name__} is not supported." 
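For readers who have not used CUDA graphs before, the capture logic above boils down to a capture-once / replay-many pattern. The sketch below is a minimal, self-contained illustration of that pattern only: `TinyGraphRunner` and its method names are hypothetical stand-ins, not ColossalAI's `CUDAGraphRunner`, which additionally manages block tables, KV caches, and `InputMetaData`.

```python
import torch


class TinyGraphRunner:
    """Illustrative capture/replay wrapper; not the actual ColossalAI CUDAGraphRunner."""

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.graph = None
        self.static_input = None
        self.static_output = None

    def capture(self, example_input: torch.Tensor, memory_pool=None):
        # Warm up on a side stream so lazy allocations happen before capture.
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream), torch.no_grad():
            self.model(example_input)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Record the kernels once; every later replay reuses these static buffers.
        self.static_input = example_input
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool), torch.no_grad():
            self.static_output = self.model(self.static_input)
        return self.graph.pool()  # can be shared with the next batch-size bucket

    def __call__(self, new_input: torch.Tensor) -> torch.Tensor:
        # Replay: copy fresh data into the captured input buffer, then launch the recorded graph.
        self.static_input.copy_(new_input)
        self.graph.replay()
        return self.static_output


if __name__ == "__main__":  # requires a CUDA device
    net = torch.nn.Linear(16, 16).cuda().eval()
    runner = TinyGraphRunner(net)
    pool = runner.capture(torch.zeros(8, 16, device="cuda"))
    out = runner(torch.randn(8, 16, device="cuda"))
```

Sharing the memory pool returned by `graph.pool()` across batch-size buckets, and capturing the largest bucket first, may help keep the per-graph memory overhead modest, as the NOTE in `capture_model` points out.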
+ + def enable_spec_dec( + self, + drafter_model: nn.Module = None, + n_spec_tokens: int = None, + use_glide_drafter: bool = False, + ) -> None: + """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. + + Args: + drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. + If provided, the previous drafter and drafter model, if exist, will be overwritten. + n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. + If not provided, `max_n_spec_tokens` in InferenceConfig will be used. + use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. + If True, the drafter model will be replaced by a glide model. + + ```python + ... + engine = InferenceEngine(model, tokenizer, inference_config) + + engine.enable_spec_dec(drafter_model, n_spec_tokens=5) + engine.generate(...) # Speculative Decoding + + engine.disable_spec_dec() + engine.generate(...) # Normal generation + + engine.enable_spec_dec() + engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens + engine.clear_spec_dec() + ``` + """ + + if drafter_model is None and self.drafter is None: + raise ValueError("Drafter not initialized. Please provide a Drafter Model") + if n_spec_tokens is not None: + assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens + self.n_spec_tokens = n_spec_tokens + if drafter_model is not None: + assert isinstance(drafter_model, nn.Module) + # overwrite the drafter, if exists + self.clear_spec_dec() + self.drafter_model = drafter_model + self.drafter = Drafter( + self.drafter_model, + self.tokenizer, + device=self.device, + dtype=self.dtype, + ) + + # check if the provided drafter model is compatible with GLIDE structure + # when `use_glide_drafter` is set to True + if ( + use_glide_drafter + and hasattr(drafter_model, "model") + and hasattr(drafter_model.model, "layers") + and hasattr(drafter_model.model.layers[0], "cross_attn") + ): + self.use_glide = use_glide_drafter + elif use_glide_drafter: + self.logger.warning( + f"`use_glide_drafter` is provided as {use_glide_drafter}, " + f"but the provided drafter model is not compatible with GLIDE structure." + f"Falling back to use the default drafter model (non-GLIDE)." + ) + self.request_handler.set_spec_dec_mode(self.n_spec_tokens) + # using speculative decoding for subsequent generations + self.use_spec_dec = True + + def disable_spec_dec(self) -> None: + """Disable using speculative decoding for subsequent generations.""" + self.request_handler.unset_spec_dec_mode() + # set back to the maximum number of tokens to speculate + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + self.use_glide = False + self.use_spec_dec = False + + def clear_spec_dec(self) -> None: + """Clear relatable structures of speculative decoding, if exist.""" + if self.use_spec_dec: + self.disable_spec_dec() + if self.drafter_model or self.drafter: + self.drafter_model = None + self.drafter = None + torch.cuda.empty_cache() + self.use_glide = False + self.use_spec_dec = False + + def steps_spec_dec(self) -> List[Sequence]: + """ + Run Speculative Decoding steps. This is like retrieving a single batch and launch inference + with many steps of speculating by a drafter model as well as verifying by a main model. + + Returns: + List[Sequence]: finished sequences generated by one step. 
+ """ + batch = self.request_handler.schedule() # prefill batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + # 1. Prefill small model (Drafter) - fill past kv cache for drafter model + # NOTE For glide drafter models, we won't actually apply glide during prefill stage + drafter_out = self.drafter.speculate(input_token_ids, 1, None) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + + # 2. Prefill main model (Verifier) - fill past kv cache for main model + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + # append new inputs to the batch, temporarily + batch.append_batch_tokens(next_tokens) + self.request_handler.allocate_batch_spec_dec(batch, 1) + already_allocated_kv_len = batch.seq_lengths[0].item() + input_token_ids = batch.get_1D_inputs_spec_dec(1) + + finished_sequences = self.request_handler.update() + + while True: + # HACK Retrieve the running batch + # Using RequestHandler.schedule here will re-allocate same kv cache for the batch + batch = self.request_handler.running_bb # running batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + # 3. Decoding - Drafter model speculates `n` tokens + glide_input = None + if self.use_glide: + glide_input = GlideInput( + batch.get_block_table_tensor(), + self.k_cache[-1], # use kv cahces of the last layer + self.v_cache[-1], + batch.get_sequence_lengths(), + n_spec_tokens=self.n_spec_tokens, + ) + + drafter_out = self.drafter.speculate( + input_token_ids, + self.n_spec_tokens, + drafter_past_key_values, + glide_input=glide_input, + ) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + drafter_spec_length = drafter_out.speculated_length + + for next_token_id_spec in next_token_ids_spec: + self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) + cur_length = batch.seq_lengths[0].item() + if already_allocated_kv_len < cur_length: + self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) + already_allocated_kv_len = cur_length + + # 4. Decoding - Main model verifies `n` tokens in parallel + if drafter_spec_length < batch.num_tokens_to_verify: + batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + + # 5. 
Compare and process the results + diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) + n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() + + # revoke appended tokens for each Sequence in the current batch + batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens + + # append the last correct token generated by the main model + self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) + + # trim past key values of the drafter model + drafter_past_key_values = Drafter.trim_kv_cache( + drafter_past_key_values, drafter_spec_length - n_matches - 1 + ) + + # prepare inputs for the next round of speculation + n = 1 if n_matches < drafter_spec_length else 2 + input_token_ids = batch.get_1D_inputs_spec_dec(n) + + self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) + finished_sequences = self.request_handler.update() + if len(finished_sequences) > 0: + break + + # Reset back the number of speculated tokens of the batch, + # this is used to handle the last round of speculation, in which case the number of speculated tokens + # by the drafter is less than the number of speculated tokens set to the engine. + batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) + + return finished_sequences + + def generate( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + return_token_ids: bool = False, + generation_config: Optional[GenerationConfig] = None, + ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: + """ + Executing the inference step. + + Args: + request_ids (List[int], optional): The request ID. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. + return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. + generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. + + Returns: + Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. + """ + + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + prompts = [prompts] if isinstance(prompts, str) else prompts + request_ids = [request_ids] if isinstance(request_ids, int) else request_ids + + with torch.inference_mode(): + if prompts is not None or prompts_token_ids is not None: + self.add_request( + request_ids=request_ids, + prompts=prompts, + prompts_token_ids=prompts_token_ids, + **gen_config_dict, + ) + + output_seqs_list = [] + total_tokens_list = [] + + # intuition: If user provide a generation config, we should replace the existing one. + if generation_config is not None: + self.generation_config = generation_config + self.generation_config_dict = gen_config_dict + + if self.use_spec_dec: + assert self.drafter is not None, "Drafter Model is not initialized." 
+ while self.request_handler.check_unfinished_reqs(): + output_seqs_list += self.steps_spec_dec() + else: + while self.request_handler.check_unfinished_reqs(): + output_seqs_list += self.step() + + output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) + + for seq in output_seqs_list: + total_tokens_list.append(seq.input_token_id + seq.output_token_id) + + output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) + + if return_token_ids: + output_tokens_list = [seq.output_token_id for seq in output_seqs_list] + return output_str, output_tokens_list + else: + return output_str + + @property + def has_prompt_template(self) -> bool: + """ """ + return self.inference_config.prompt_template is not None + + def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: + """ + This method will format the input prompt according to the prompt template given to the InferenceConfig. + """ + assert ( + self.has_prompt_template + ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." + + if isinstance(prompts, (list, tuple)): + return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] + elif isinstance(prompts, str): + return self.inference_config.prompt_template.format(input_text=prompts) + else: + raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") + + def add_request( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + **kwargs, + ) -> None: + """ + Add requests. + + Args: + request_ids (List[int], optional): The request ID. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + """ + + # apply the prompt template to the input prompts + + if self.has_prompt_template and prompts is not None: + prompts = self.format_prompt(prompts) + + block_size = self.inference_config.block_size + + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + + if prompts is not None and not isinstance(prompts, list): + prompts = [prompts] + + if prompts_token_ids is None: + assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." + prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ + "input_ids" + ] + + # list of torch Tensor + if isinstance(prompts_token_ids, list): + if isinstance(prompts_token_ids[0], torch.Tensor): + prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] + elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): + prompts_token_ids = prompts_token_ids.tolist() + else: + raise TypeError( + f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." + ) + + assert ( + len(prompts_token_ids[0]) <= self.inference_config.max_input_len + ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." 
+ + prompts_num = len(prompts_token_ids) + + for i in range(prompts_num): + if request_ids: + assert isinstance( + request_ids[0], int + ), f"The request_id type must be int, but got {type(request_ids[0])}" + assert len(request_ids) == prompts_num + request_id = request_ids[i] + else: + request_id = next(self.counter) + if prompts == None: + prompt = None + else: + prompt = prompts[i] + + max_length = kwargs.get("max_length", None) + max_new_tokens = kwargs.get("max_new_tokens", None) + if max_length is None and max_new_tokens is None: + max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len + elif max_length is not None: + max_new_tokens = max_length - len(prompts_token_ids[i]) + + if not self.inference_config.enable_streamingllm: + assert ( + self.inference_config.max_output_len >= max_new_tokens + ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." + + sequence = Sequence( + request_id, + prompt, + prompts_token_ids[i], + block_size, + None, + self.tokenizer.eos_token_id, + self.tokenizer.pad_token_id, + max_output_len=max_new_tokens, + ignore_eos=self.inference_config.ignore_eos, + ) + self.request_handler.add_sequence(sequence) + + def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: + input_ids = batch.get_1D_inputs() + sequence_lengths = batch.get_sequence_lengths() + + if batch.is_prompts: + n_tokens = sequence_lengths.sum().item() + else: + n_tokens = batch.current_batch_size + if batch.use_spec_dec: + n_tokens = batch.num_tokens_to_verify + 1 + assert n_tokens == input_ids.size(0) + n_tokens = n_tokens * batch.current_batch_size + output_tensor = torch.zeros( + (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device + ) + + batch_token_ids = None + if ( + self.generation_config.repetition_penalty != 1.0 + or self.generation_config.no_repeat_ngram_size > 0 + or self.generation_config.forced_eos_token_id is not None + ): + batch_token_ids = batch.batch_token_ids + + # only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph = False + if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): + use_cuda_graph = True + + input_meta_data = InputMetaData( + block_tables=batch.get_block_table_tensor(), + sequence_lengths=sequence_lengths, + fd_inter_tensor=batch.fd_inter_tensor, + batch_size=batch.current_batch_size, + is_prompts=batch.is_prompts, + use_cuda_kernel=self.inference_config.use_cuda_kernel, + use_cuda_graph=use_cuda_graph, + high_precision=self.high_precision, + kv_seq_len=sequence_lengths.max().item(), + head_dim=batch.head_dim, + dtype=batch.dtype, + use_spec_dec=batch.use_spec_dec, + num_tokens_to_verify=batch.num_tokens_to_verify, + batch_token_ids=batch_token_ids, + ) + + return input_ids, output_tensor, input_meta_data + + def step(self) -> List[str]: + """ + In each step, do the follows: + 1. Run RequestHandler.schedule() and get the batch used for inference. + 2. Get the input, inputinfo and output placeholder from the batchbucket + 3. Run model to generate the next token + 4. Update waiting list and running list in RequestHandler and get finished sequences. + 5. Decode and return finished sequences. + + Returns: + List[str]: Decoded finished sequences generated by one step. 
+        """
+
+        batch = self.request_handler.schedule()
+
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
+        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]
+
+        if self.inference_config.enable_streamingllm:
+            updated_block_ids = batch.streamingllm_update_batch(
+                self.inference_config.start_token_size, self.inference_config.generated_token_size
+            )
+            self.request_handler.streamingllm_free_block_tables(updated_block_ids)
+
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
+        self.request_handler.append_next_tokens(next_tokens)
+        finished_sequences = self.request_handler.update()
+
+        return finished_sequences
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 512eaea71..393347c31 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
-from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -98,7 +98,46 @@ class RunningList:
         self._decoding[seq_id] = self._prefill.pop(seq_id)
 
 
-class RequestHandler:
+class NaiveRequestHandler:
+    def __init__(self) -> None:
+        self.running_list: List[DiffusionSequence] = []
+        self.waiting_list: List[DiffusionSequence] = []
+
+    def _has_waiting(self) -> bool:
+        return any(lst for lst in self.waiting_list)
+
+    def _has_running(self) -> bool:
+        return any(lst for lst in self.running_list)
+
+    def check_unfinished_reqs(self):
+        return self._has_waiting() or self._has_running()
+
+    def add_sequence(self, seq: DiffusionSequence):
+        """
+        Add the request to the waiting list.
+        """
+        assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
+        self.waiting_list.append(seq)
+
+    def _find_sequence(self, request_id: int) -> DiffusionSequence:
+        """
+        Find the request by request_id.
+        """
+        for lst in (self.waiting_list, self.running_list):
+            for seq in lst:
+                if seq.request_id == request_id:
+                    return seq
+        return None
+
+    def schedule(self):
+        ret = None
+        if self._has_waiting():
+            ret = self.waiting_list[0]
+            self.waiting_list = self.waiting_list[1:]
+        return ret
+
+
+class RequestHandler(NaiveRequestHandler):
     """
     RequestHandler is the core for handling existing requests and updating current batch. During generation process, we call schedule function each iteration to update current batch.
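To make the control flow of `NaiveRequestHandler` concrete, here is a small usage sketch. `_StubSeq` is a hypothetical stand-in for `DiffusionSequence`, whose real definition lives in `colossalai/inference/struct.py` and is not shown in this hunk; only the `request_id` attribute the handler actually touches is modeled.

```python
from dataclasses import dataclass

from colossalai.inference.core.request_handler import NaiveRequestHandler


@dataclass
class _StubSeq:
    # Hypothetical stand-in for DiffusionSequence.
    request_id: int
    prompt: str


handler = NaiveRequestHandler()
handler.add_sequence(_StubSeq(0, "an astronaut riding a horse"))
handler.add_sequence(_StubSeq(1, "a watercolor city skyline"))

while handler.check_unfinished_reqs():
    seq = handler.schedule()  # FIFO: pops the oldest waiting request
    if seq is None:
        break
    print(f"running request {seq.request_id}: {seq.prompt}")
```

Because `schedule()` only ever pops the head of the waiting list, requests are served strictly in arrival order.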
@@ -176,12 +215,12 @@ class RequestHandler: generated_token_size=inference_config.generated_token_size, ) + def _has_running(self) -> bool: + return not self.running_bb.is_empty() + def _init_cache(self, model_config): self.cache_manager = KVCacheManager(self.inference_config, model_config) - def _has_waiting(self) -> bool: - return any(lst for lst in self.waiting_list) - def get_kvcache(self): return self.cache_manager.get_kv_cache() @@ -318,7 +357,7 @@ class RequestHandler: if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens: seq.mark_finished() - def check_unfinished_seqs(self) -> bool: + def check_unfinished_reqs(self) -> bool: return self._has_waiting() or not self.running_list.is_empty() def total_requests_in_batch_bucket(self) -> int: diff --git a/colossalai/inference/modeling/layers/diffusion.py b/colossalai/inference/modeling/layers/diffusion.py new file mode 100644 index 000000000..9dc90733d --- /dev/null +++ b/colossalai/inference/modeling/layers/diffusion.py @@ -0,0 +1,54 @@ +import inspect +import types + +import torch +from torch import nn + + +class DiffusionPipe(nn.Module): + """ + This Class convert a class of `DiffusionPipeline` into `nn.Module` and reserve most of origin attr,function and property. + """ + + def __init__(self, source_obj) -> None: + super(DiffusionPipe, self).__init__() + + for k, v in source_obj.__dict__.items(): + if isinstance(v, nn.Module): + self.add_module(k, v) + else: + setattr(self, k, v) + + skip_list = ["_execution_device", "to", "device"] # this + + for name, member in inspect.getmembers(source_obj.__class__): + if name in skip_list: + continue + if not name.startswith("__") and not name.endswith("__"): + if isinstance(member, property): + setattr(self.__class__, name, member) + elif inspect.isfunction(member) or inspect.ismethod(member): + bound_method = types.MethodType(member, self) + setattr(self, name, bound_method) + elif not callable(member) and not isinstance(member, property): + setattr(self, name, member) + elif name == "__call__": + bound_method = types.MethodType(member, self) + setattr(self, "_forward", bound_method) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. 
+ """ + # return self.device + return torch.device("cuda") + + @property + def device(self): + next(self.parameters()).device + + def forward(self, *args, **kwargs): + return self._forward(*args, **kwargs) diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py new file mode 100644 index 000000000..ea97cceef --- /dev/null +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -0,0 +1,626 @@ +# Code refer and adapted from: +# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers +# https://github.com/PipeFusion/PipeFusion + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from diffusers.models import attention_processor +from diffusers.models.attention import Attention +from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed +from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel +from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel +from torch import nn +from torch.distributed import ProcessGroup + +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.logging import get_dist_logger +from colossalai.shardformer.layer.parallel_module import ParallelModule +from colossalai.utils import get_current_device + +try: + from flash_attn import flash_attn_func + + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + + +logger = get_dist_logger(__name__) + + +# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py +def PixArtAlphaTransformer2DModel_forward( + self: PixArtTransformer2DModel, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, +): + assert hasattr( + self, "patched_parallel_size" + ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`" + + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) + hidden_states = self.pos_embed(hidden_states) + + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk( + 2, dim=1 + ) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=( + -1, + height // self.patched_parallel_size, + width, + self.config.patch_size, + self.config.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height // self.patched_parallel_size * self.config.patch_size, + width * self.config.patch_size, + ) + ) + + # enable Distrifusion Optimization + if hasattr(self, "patched_parallel_size"): + from torch import distributed as dist + + if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape): + self.output_buffer = torch.empty_like(output) + if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape): + self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] + output = output.contiguous() + dist.all_gather(self.buffer_list, output, async_op=False) + torch.cat(self.buffer_list, dim=2, out=self.output_buffer) + output = self.output_buffer + + return (output,) + + +# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py +def SD3Transformer2DModel_forward( + self: SD3Transformer2DModel, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] 
= None, + return_dict: bool = True, +) -> Union[torch.FloatTensor]: + + assert hasattr( + self, "patched_parallel_size" + ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`" + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size // self.patched_parallel_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + # enable Distrifusion Optimization + if hasattr(self, "patched_parallel_size"): + from torch import distributed as dist + + if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape): + self.output_buffer = torch.empty_like(output) + if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape): + self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] + output = output.contiguous() + dist.all_gather(self.buffer_list, output, async_op=False) + torch.cat(self.buffer_list, dim=2, out=self.output_buffer) + output = self.output_buffer + + return (output,) + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py +class DistrifusionPatchEmbed(ParallelModule): + def __init__( + self, + module: PatchEmbed, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.module = module + self.rank = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + + @staticmethod + def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs): + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + distrifusion_embed = DistrifusionPatchEmbed( + module, process_group, model_shard_infer_config=model_shard_infer_config + ) + return distrifusion_embed + + def forward(self, latent): + module = self.module + if module.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size + + latent = module.proj(latent) + if module.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if module.layer_norm: + latent = module.norm(latent) + if module.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if module.pos_embed_max_size: + pos_embed = module.cropped_pos_embed(height, width) + else: + if module.height != height or module.width != width: + pos_embed = get_2d_sincos_pos_embed( + 
embed_dim=module.pos_embed.shape[-1], + grid_size=(height, width), + base_size=module.base_size, + interpolation_scale=module.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = module.pos_embed + + b, c, h = pos_embed.shape + pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank] + + return (latent + pos_embed).to(latent.dtype) + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py +class DistrifusionConv2D(ParallelModule): + + def __init__( + self, + module: nn.Conv2d, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.module = module + self.rank = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + + @staticmethod + def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs): + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config) + return distrifusion_conv + + def sliced_forward(self, x: torch.Tensor) -> torch.Tensor: + + b, c, h, w = x.shape + + stride = self.module.stride[0] + padding = self.module.padding[0] + + output_h = x.shape[2] // stride // self.patched_parallelism_size + idx = dist.get_rank() + h_begin = output_h * idx * stride - padding + h_end = output_h * (idx + 1) * stride + padding + final_padding = [padding, padding, 0, 0] + if h_begin < 0: + h_begin = 0 + final_padding[2] = padding + if h_end > h: + h_end = h + final_padding[3] = padding + sliced_input = x[:, :, h_begin:h_end, :] + padded_input = F.pad(sliced_input, final_padding, mode="constant") + return F.conv2d( + padded_input, + self.module.weight, + self.module.bias, + stride=stride, + padding="valid", + ) + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + output = self.sliced_forward(input) + return output + + +# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py +class DistrifusionFusedAttention(ParallelModule): + + def __init__( + self, + module: attention_processor.Attention, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.counter = 0 + self.module = module + self.buffer_list = None + self.kv_buffer_idx = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + self.handle = None + self.process_group = process_group + self.warm_step = 5 # for warmup + + @staticmethod + def from_native_module( + module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + return DistrifusionFusedAttention( + module=module, + process_group=process_group, + model_shard_infer_config=model_shard_infer_config, + ) + + def _forward( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = 
hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2) + + if self.patched_parallelism_size == 1: + full_kv = kv + else: + if self.buffer_list is None: # buffer not created + full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1) + elif self.counter <= self.warm_step: + # logger.info(f"warmup: {self.counter}") + dist.all_gather( + self.buffer_list, + kv, + group=self.process_group, + async_op=False, + ) + full_kv = torch.cat(self.buffer_list, dim=1) + else: + # logger.info(f"use old kv to infer: {self.counter}") + self.buffer_list[self.kv_buffer_idx].copy_(kv) + full_kv = torch.cat(self.buffer_list, dim=1) + assert self.handle is None, "we should maintain the kv of last step" + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) + + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) # NOTE(@lry89757) for torch >= 2.2, flash attn has been already integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/ + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. 
+ hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + + if self.handle is not None: + self.handle.wait() + self.handle = None + + b, l, c = hidden_states.shape + kv_shape = (b, l, self.module.to_k.out_features * 2) + if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape): + + self.buffer_list = [ + torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) + for _ in range(self.patched_parallelism_size) + ] + + self.counter = 0 + + attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + output = self._forward( + self.module, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + self.counter += 1 + + return output + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py +class DistriSelfAttention(ParallelModule): + def __init__( + self, + module: Attention, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.counter = 0 + self.module = module + self.buffer_list = None + self.kv_buffer_idx = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + self.handle = None + self.process_group = process_group + self.warm_step = 3 # for warmup + + @staticmethod + def from_native_module( + module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + return DistriSelfAttention( + module=module, + process_group=process_group, + model_shard_infer_config=model_shard_infer_config, + ) + + def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): + attn = self.module + assert isinstance(attn, Attention) + + residual = hidden_states + + batch_size, sequence_length, _ = hidden_states.shape + + query = attn.to_q(hidden_states) + + encoder_hidden_states = hidden_states + k = self.module.to_k(encoder_hidden_states) + v = self.module.to_v(encoder_hidden_states) + kv = torch.cat([k, v], dim=-1) # shape of 
kv now: (bs, seq_len // parallel_size, dim * 2) + + if self.patched_parallelism_size == 1: + full_kv = kv + else: + if self.buffer_list is None: # buffer not created + full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1) + elif self.counter <= self.warm_step: + # logger.info(f"warmup: {self.counter}") + dist.all_gather( + self.buffer_list, + kv, + group=self.process_group, + async_op=False, + ) + full_kv = torch.cat(self.buffer_list, dim=1) + else: + # logger.info(f"use old kv to infer: {self.counter}") + self.buffer_list[self.kv_buffer_idx].copy_(kv) + full_kv = torch.cat(self.buffer_list, dim=1) + assert self.handle is None, "we should maintain the kv of last step" + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) + + if HAS_FLASH_ATTN: + # flash attn + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype) + else: + # naive attn + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + *args, + **kwargs, + ) -> torch.FloatTensor: + + # async preallocates memo buffer + if self.handle is not None: + self.handle.wait() + self.handle = None + + b, l, c = hidden_states.shape + kv_shape = (b, l, self.module.to_k.out_features * 2) + if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape): + + self.buffer_list = [ + torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) + for _ in range(self.patched_parallelism_size) + ] + + self.counter = 0 + + output = self._forward(hidden_states, scale=scale) + + self.counter += 1 + return output diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py new file mode 100644 index 000000000..cc2bee5ef --- /dev/null +++ b/colossalai/inference/modeling/models/pixart_alpha.py @@ -0,0 +1,220 @@ +# Code adapted from: +# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py + +from 
typing import Callable, List, Optional, Union + +import PIL.Image +import torch +from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from colossalai.logging import get_dist_logger + +from ..layers.diffusion import DiffusionPipe + +logger = get_dist_logger(__name__) + + +@torch.no_grad() +def pixart_alpha_forward( + self: DiffusionPipe, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, +) -> PIL.Image: + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + # self.maybe_free_model_hooks() + + return image diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py new file mode 100644 index 000000000..b12316403 --- /dev/null +++ b/colossalai/inference/modeling/models/stablediffusion3.py @@ -0,0 +1,178 @@ +# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps + +from ..layers.diffusion import DiffusionPipe + + +# TODO(@lry89757) temporarily image, please support more return output +@torch.no_grad() +def sd3_forward( + self: DiffusionPipe, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, 
+ num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + return image diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py index fa0395590..02ffadd9f 100644 --- a/colossalai/inference/modeling/policy/__init__.py +++ b/colossalai/inference/modeling/policy/__init__.py @@ -1,16 +1,22 @@ from .glide_llama import GlideLlamaModelPolicy from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy from .nopadding_llama import NoPaddingLlamaModelInferPolicy +from .pixart_alpha import PixArtAlphaInferPolicy +from .stablediffusion3 import StableDiffusion3InferPolicy model_policy_map = { "nopadding_llama": NoPaddingLlamaModelInferPolicy, "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy, "glide_llama": GlideLlamaModelPolicy, + "StableDiffusion3Pipeline": StableDiffusion3InferPolicy, + "PixArtAlphaPipeline": PixArtAlphaInferPolicy, } __all__ = [ "NoPaddingLlamaModelInferPolicy", "NoPaddingBaichuanModelInferPolicy", "GlideLlamaModelPolicy", + "StableDiffusion3InferPolicy", + "PixArtAlphaInferPolicy", "model_polic_map", ] diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py new file mode 100644 index 000000000..1150b2432 --- /dev/null +++ 
b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -0,0 +1,79 @@ +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel +from torch import nn + +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.distrifusion import ( + DistrifusionConv2D, + DistrifusionPatchEmbed, + DistriSelfAttention, + PixArtAlphaTransformer2DModel_forward, +) +from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class PixArtAlphaInferPolicy(Policy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = {} + + if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: + + policy[PixArtTransformer2DModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="pos_embed.proj", + target_module=DistrifusionConv2D, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + SubModuleReplacementDescription( + suffix="pos_embed", + target_module=DistrifusionPatchEmbed, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + ], + attribute_replacement={ + "patched_parallel_size": self.shard_config.extra_kwargs[ + "model_shard_infer_config" + ].patched_parallelism_size + }, + method_replacement={"forward": PixArtAlphaTransformer2DModel_forward}, + ) + + policy[BasicTransformerBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn1", + target_module=DistriSelfAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, + ) + ] + ) + + self.append_or_create_method_replacement( + description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe + ) + + return policy + + def preprocess(self) -> nn.Module: + return self.model + + def postprocess(self): + return self.model + + def config_sanity_check(self): + pass + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "PixArtAlphaInferPolicy": + return PixArtAlphaInferPolicy() diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py new file mode 100644 index 000000000..39b764b92 --- /dev/null +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -0,0 +1,78 @@ +from diffusers.models.attention import JointTransformerBlock +from diffusers.models.transformers import SD3Transformer2DModel +from torch import nn + +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.distrifusion import ( + DistrifusionConv2D, + DistrifusionFusedAttention, + DistrifusionPatchEmbed, + SD3Transformer2DModel_forward, +) +from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class StableDiffusion3InferPolicy(Policy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def 
module_policy(self): + policy = {} + + if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: + + policy[SD3Transformer2DModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="pos_embed.proj", + target_module=DistrifusionConv2D, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + SubModuleReplacementDescription( + suffix="pos_embed", + target_module=DistrifusionPatchEmbed, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + ], + attribute_replacement={ + "patched_parallel_size": self.shard_config.extra_kwargs[ + "model_shard_infer_config" + ].patched_parallelism_size + }, + method_replacement={"forward": SD3Transformer2DModel_forward}, + ) + + policy[JointTransformerBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn", + target_module=DistrifusionFusedAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, + ) + ] + ) + + self.append_or_create_method_replacement( + description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe + ) + return policy + + def preprocess(self) -> nn.Module: + return self.model + + def postprocess(self): + return self.model + + def config_sanity_check(self): + pass + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "StableDiffusion3InferPolicy": + return StableDiffusion3InferPolicy() diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 1a3094a27..65d284296 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -2,6 +2,7 @@ import enum from dataclasses import dataclass from typing import Any, List +from colossalai.inference.config import DiffusionGenerationConfig from colossalai.logging import get_dist_logger logger = get_dist_logger(__name__) @@ -46,6 +47,17 @@ class RequestStatus(enum.Enum): return status == RequestStatus.WAITING +@dataclass +class DiffusionSequence: + """ + parameters for diffusion + """ + + request_id: int + prompt: str + generation_config: DiffusionGenerationConfig + + @dataclass class Sequence: """Store information of input sequence. 
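Before the next file: the attention wrappers introduced in distrifusion.py (`DistriSelfAttention`, `DistrifusionFusedAttention`) share one communication pattern. Each rank projects K/V only for its own latent patch, keeps a per-rank `buffer_list`, all-gathers synchronously for the first few denoising steps (`warm_step` is 3 and 5 respectively), and afterwards reuses the other ranks' K/V from the previous step while refreshing the cache with an asynchronous all-gather that is waited on at the next call. The standalone sketch below restates just that bookkeeping under those assumptions; `StaleKVCache` and its method names are illustrative and are not part of the patch.

```python
# Sketch of the warmup + stale-KV bookkeeping used by DistriSelfAttention /
# DistrifusionFusedAttention above. Names here are illustrative only.
from typing import List, Optional

import torch
import torch.distributed as dist


class StaleKVCache:
    def __init__(self, process_group: dist.ProcessGroup, warm_steps: int = 3):
        self.group = process_group
        self.rank = dist.get_rank(group=process_group)
        self.world_size = dist.get_world_size(group=process_group)
        self.warm_steps = warm_steps
        self.counter = 0                 # denoising steps seen so far
        self.handle = None               # pending async all-gather, if any
        self.buffers: Optional[List[torch.Tensor]] = None  # one K/V slot per rank

    def full_kv(self, kv: torch.Tensor) -> torch.Tensor:
        """kv: this rank's (batch, seq_len // world_size, 2 * dim) K/V patch."""
        if self.world_size == 1:
            return kv
        if self.handle is not None:      # finish the gather launched last step
            self.handle.wait()
            self.handle = None
        if self.buffers is None or self.buffers[0].shape != kv.shape:
            self.buffers = [torch.empty_like(kv) for _ in range(self.world_size)]
            self.counter = 0
        if self.counter <= self.warm_steps:
            # warmup: block until every rank's K/V for the current step is in place
            dist.all_gather(self.buffers, kv, group=self.group, async_op=False)
            full = torch.cat(self.buffers, dim=1)
        else:
            # steady state: write the fresh local patch, reuse the other ranks'
            # K/V from the previous step, then refresh the cache in the background
            self.buffers[self.rank].copy_(kv)
            full = torch.cat(self.buffers, dim=1)
            self.handle = dist.all_gather(self.buffers, kv, group=self.group, async_op=True)
        self.counter += 1
        return full
```

A wrapper would call `full_kv(...)` once per denoising step at the point where the attention forwards above rebuild `full_kv` before splitting it into `key` and `value`; the buffer reallocation on shape change mirrors the `kv_shape` check in their `forward` methods.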
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 332e84d37..d0851e362 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -5,10 +5,12 @@ Utils for model inference import math import os import re +from enum import Enum from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch +from diffusers import DiffusionPipeline from torch import nn from colossalai.logging import get_dist_logger @@ -159,3 +161,36 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool: except ImportError: logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") return False + + +class ModelType(Enum): + DIFFUSION_MODEL = "Diffusion Model" + LLM = "Large Language Model (LLM)" + UNKNOWN = "Unknown Model Type" + + +def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]): + if isinstance(model_or_path, DiffusionPipeline): + return ModelType.DIFFUSION_MODEL + elif isinstance(model_or_path, nn.Module): + return ModelType.LLM + elif isinstance(model_or_path, str): + try: + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + return ModelType.LLM + except: + """ + model type is not `ModelType.LLM` + """ + + try: + DiffusionPipeline.load_config(model_or_path) + return ModelType.DIFFUSION_MODEL + except: + """ + model type is not `ModelType.DIFFUSION_MODEL` + """ + else: + return ModelType.UNKNOWN diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 71d42312e..4e2eff7ce 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -3,6 +3,12 @@ import os +# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation, +# the order of kernel launches on GPUs is the same as on the CPU so that comm is launched first.
+# see https://github.com/NVIDIA/Megatron-LM/issues/533 +# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16 +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + import torch.distributed as dist from colossalai.accelerator import get_accelerator diff --git a/colossalai/shardformer/layer/moe/__init__.py b/colossalai/legacy/moe/layer/__init__.py similarity index 100% rename from colossalai/shardformer/layer/moe/__init__.py rename to colossalai/legacy/moe/layer/__init__.py diff --git a/colossalai/shardformer/layer/moe/experts.py b/colossalai/legacy/moe/layer/experts.py similarity index 95% rename from colossalai/shardformer/layer/moe/experts.py rename to colossalai/legacy/moe/layer/experts.py index 1be7a2754..8088cf44e 100644 --- a/colossalai/shardformer/layer/moe/experts.py +++ b/colossalai/legacy/moe/layer/experts.py @@ -5,9 +5,9 @@ import torch import torch.nn as nn from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON -from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import get_activation +from colossalai.legacy.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.utils import get_activation +from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size @@ -118,7 +118,7 @@ class MLPExperts(nn.Module): Returns: torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) """ - x = MoeInGradScaler.apply(x, self.ep_size) + x = EPGradScalerIn.apply(x, self.ep_size) e = x.size(1) h = x.size(-1) @@ -157,5 +157,5 @@ class MLPExperts(nn.Module): x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() - x = MoeOutGradScaler.apply(x, self.ep_size) + x = EPGradScalerOut.apply(x, self.ep_size) return x diff --git a/colossalai/shardformer/layer/moe/layers.py b/colossalai/legacy/moe/layer/layers.py similarity index 99% rename from colossalai/shardformer/layer/moe/layers.py rename to colossalai/legacy/moe/layer/layers.py index e5b0ef97f..e43966f68 100644 --- a/colossalai/shardformer/layer/moe/layers.py +++ b/colossalai/legacy/moe/layer/layers.py @@ -7,9 +7,9 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from colossalai.legacy.moe.load_balance import LoadBalancer +from colossalai.legacy.moe.utils import create_ep_hierarchical_group, get_noise_generator from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter -from colossalai.moe.load_balance import LoadBalancer -from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator from colossalai.shardformer.layer.moe import MLPExperts from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size diff --git a/colossalai/shardformer/layer/moe/routers.py b/colossalai/legacy/moe/layer/routers.py similarity index 95% rename from colossalai/shardformer/layer/moe/routers.py rename to colossalai/legacy/moe/layer/routers.py index 1be7a2754..8088cf44e 100644 --- a/colossalai/shardformer/layer/moe/routers.py +++ b/colossalai/legacy/moe/layer/routers.py @@ -5,9 +5,9 @@ import torch import torch.nn as nn from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON -from colossalai.moe._operation import 
MoeInGradScaler, MoeOutGradScaler -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import get_activation +from colossalai.legacy.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.utils import get_activation +from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size @@ -118,7 +118,7 @@ class MLPExperts(nn.Module): Returns: torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) """ - x = MoeInGradScaler.apply(x, self.ep_size) + x = EPGradScalerIn.apply(x, self.ep_size) e = x.size(1) h = x.size(-1) @@ -157,5 +157,5 @@ class MLPExperts(nn.Module): x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() - x = MoeOutGradScaler.apply(x, self.ep_size) + x = EPGradScalerOut.apply(x, self.ep_size) return x diff --git a/colossalai/moe/load_balance.py b/colossalai/legacy/moe/load_balance.py similarity index 99% rename from colossalai/moe/load_balance.py rename to colossalai/legacy/moe/load_balance.py index 3dc6c02c7..7339b1a7b 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/legacy/moe/load_balance.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torch.distributed import ProcessGroup from colossalai.cluster import ProcessGroupMesh -from colossalai.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.manager import MOE_MANAGER from colossalai.shardformer.layer.moe import MLPExperts from colossalai.zero.low_level import LowLevelZeroOptimizer diff --git a/colossalai/moe/manager.py b/colossalai/legacy/moe/manager.py similarity index 100% rename from colossalai/moe/manager.py rename to colossalai/legacy/moe/manager.py diff --git a/examples/language/openmoe/README.md b/colossalai/legacy/moe/openmoe/README.md similarity index 100% rename from examples/language/openmoe/README.md rename to colossalai/legacy/moe/openmoe/README.md diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py similarity index 99% rename from examples/language/openmoe/benchmark/benchmark_cai.py rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py index b9ef915c3..5f9447246 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py @@ -18,9 +18,9 @@ from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator +from colossalai.legacy.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.utils import skip_init from colossalai.moe.layers import apply_load_balance -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.sh similarity index 100% rename from examples/language/openmoe/benchmark/benchmark_cai.sh rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.sh diff --git a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai_dist.sh similarity index 100% rename from examples/language/openmoe/benchmark/benchmark_cai_dist.sh rename to 
colossalai/legacy/moe/openmoe/benchmark/benchmark_cai_dist.sh diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py similarity index 98% rename from examples/language/openmoe/benchmark/benchmark_fsdp.py rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py index b00fbd001..1ae94dd90 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py @@ -14,7 +14,7 @@ from torch.utils.data.distributed import DistributedSampler from transformers.models.llama import LlamaConfig from utils import PerformanceEvaluator, get_model_numel -from colossalai.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.manager import MOE_MANAGER class RandomDataset(Dataset): diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.sh similarity index 100% rename from examples/language/openmoe/benchmark/benchmark_fsdp.sh rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.sh diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/colossalai/legacy/moe/openmoe/benchmark/hostfile.txt similarity index 100% rename from examples/language/openmoe/benchmark/hostfile.txt rename to colossalai/legacy/moe/openmoe/benchmark/hostfile.txt diff --git a/examples/language/openmoe/benchmark/utils.py b/colossalai/legacy/moe/openmoe/benchmark/utils.py similarity index 100% rename from examples/language/openmoe/benchmark/utils.py rename to colossalai/legacy/moe/openmoe/benchmark/utils.py diff --git a/examples/language/openmoe/infer.py b/colossalai/legacy/moe/openmoe/infer.py similarity index 100% rename from examples/language/openmoe/infer.py rename to colossalai/legacy/moe/openmoe/infer.py diff --git a/examples/language/openmoe/infer.sh b/colossalai/legacy/moe/openmoe/infer.sh similarity index 100% rename from examples/language/openmoe/infer.sh rename to colossalai/legacy/moe/openmoe/infer.sh diff --git a/examples/language/openmoe/model/__init__.py b/colossalai/legacy/moe/openmoe/model/__init__.py similarity index 100% rename from examples/language/openmoe/model/__init__.py rename to colossalai/legacy/moe/openmoe/model/__init__.py diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.py similarity index 100% rename from examples/language/openmoe/model/convert_openmoe_ckpt.py rename to colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.py diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.sh similarity index 100% rename from examples/language/openmoe/model/convert_openmoe_ckpt.sh rename to colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.sh diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/colossalai/legacy/moe/openmoe/model/modeling_openmoe.py similarity index 99% rename from examples/language/openmoe/model/modeling_openmoe.py rename to colossalai/legacy/moe/openmoe/model/modeling_openmoe.py index 1febacd7d..5d6e91765 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/colossalai/legacy/moe/openmoe/model/modeling_openmoe.py @@ -50,8 +50,8 @@ try: except: HAS_FLASH_ATTN = False from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import get_activation, set_moe_args +from colossalai.legacy.moe.manager 
import MOE_MANAGER +from colossalai.legacy.moe.utils import get_activation, set_moe_args from colossalai.shardformer.layer.moe import SparseMLP if HAS_TRITON: diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/colossalai/legacy/moe/openmoe/model/openmoe_8b_config.json similarity index 100% rename from examples/language/openmoe/model/openmoe_8b_config.json rename to colossalai/legacy/moe/openmoe/model/openmoe_8b_config.json diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/colossalai/legacy/moe/openmoe/model/openmoe_base_config.json similarity index 100% rename from examples/language/openmoe/model/openmoe_base_config.json rename to colossalai/legacy/moe/openmoe/model/openmoe_base_config.json diff --git a/examples/language/openmoe/model/openmoe_policy.py b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py similarity index 99% rename from examples/language/openmoe/model/openmoe_policy.py rename to colossalai/legacy/moe/openmoe/model/openmoe_policy.py index f46062128..ccd566b08 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py @@ -9,7 +9,7 @@ from torch.nn import Module from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import logging -from colossalai.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.manager import MOE_MANAGER from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/examples/language/openmoe/requirements.txt b/colossalai/legacy/moe/openmoe/requirements.txt similarity index 100% rename from examples/language/openmoe/requirements.txt rename to colossalai/legacy/moe/openmoe/requirements.txt diff --git a/examples/language/openmoe/test_ci.sh b/colossalai/legacy/moe/openmoe/test_ci.sh similarity index 100% rename from examples/language/openmoe/test_ci.sh rename to colossalai/legacy/moe/openmoe/test_ci.sh diff --git a/examples/language/openmoe/train.py b/colossalai/legacy/moe/openmoe/train.py similarity index 99% rename from examples/language/openmoe/train.py rename to colossalai/legacy/moe/openmoe/train.py index ff0e4bad6..0173f0964 100644 --- a/examples/language/openmoe/train.py +++ b/colossalai/legacy/moe/openmoe/train.py @@ -19,7 +19,7 @@ from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator -from colossalai.moe.utils import skip_init +from colossalai.legacy.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer.layer.moe import apply_load_balance diff --git a/examples/language/openmoe/train.sh b/colossalai/legacy/moe/openmoe/train.sh similarity index 100% rename from examples/language/openmoe/train.sh rename to colossalai/legacy/moe/openmoe/train.sh diff --git a/colossalai/moe/utils.py b/colossalai/legacy/moe/utils.py similarity index 99% rename from colossalai/moe/utils.py rename to colossalai/legacy/moe/utils.py index 3d08ab7dd..d91c41363 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/legacy/moe/utils.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from torch.distributed.distributed_c10d import get_process_group_ranks from colossalai.accelerator import get_accelerator -from 
colossalai.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import is_moe_tensor diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py index f01da97ba..8b8f04ccf 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -81,7 +81,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function): handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 0623d19ef..e69de29bb 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,5 +0,0 @@ -from .manager import MOE_MANAGER - -__all__ = [ - "MOE_MANAGER", -] diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 01c837ee3..ac422a4da 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False): return torch.cumsum(inputs, dim=0) - 1 -class MoeInGradScaler(torch.autograd.Function): +class EPGradScalerIn(torch.autograd.Function): """ Scale the gradient back by the number of experts because the batch size increases in the moe stage @@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function): @staticmethod def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor: - if ctx is not None: - ctx.ep_size = ep_size + ctx.ep_size = ep_size return inputs @staticmethod @@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function): return grad, None -class MoeOutGradScaler(torch.autograd.Function): +class EPGradScalerOut(torch.autograd.Function): """ Scale the gradient by the number of experts because the batch size increases in the moe stage @@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function): return grad, None +class DPGradScalerIn(torch.autograd.Function): + """ + Scale the gradient back by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor: + assert activated_experts != 0, f"shouldn't be called when no expert is activated" + ctx.moe_dp_size = moe_dp_size + ctx.activated_experts = activated_experts + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.moe_dp_size != ctx.activated_experts: + grad.mul_(ctx.activated_experts / ctx.moe_dp_size) + return grad, None, None + + +class DPGradScalerOut(torch.autograd.Function): + """ + Scale the gradient by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor: + assert activated_experts != 0, f"shouldn't be called when no expert is activated" + ctx.moe_dp_size = moe_dp_size + ctx.activated_experts = activated_experts + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]: + assert 
len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.moe_dp_size != ctx.activated_experts: + grad.mul_(ctx.moe_dp_size / ctx.activated_experts) + return grad, None, None + + def _all_to_all( inputs: torch.Tensor, input_split_sizes: Optional[List[int]] = None, @@ -393,4 +436,7 @@ def all_to_all_uneven( group=None, overlap: bool = False, ): + assert ( + inputs.requires_grad + ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index ed190eb08..b7b284213 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -91,7 +91,11 @@ def _broadcast_object_list( my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. if my_rank == src: - if Version(torch.__version__) >= Version("1.13.0"): + if Version(torch.__version__) >= Version("2.3.0"): + tensor_list, size_list = zip( + *[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list] + ) + elif Version(torch.__version__) >= Version("1.13.0"): tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list]) else: tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) @@ -276,7 +280,11 @@ def _send_recv_serialization_object( send_object_tensor = None send_object_size_tensor = None if object is not None and send_dst is not None: - if Version(torch.__version__) >= Version("1.13.0"): + if Version(torch.__version__) >= Version("2.3.0"): + send_object_tensor, send_object_size_tensor = c10d._object_to_tensor( + object, device=current_device, group=send_group + ) + elif Version(torch.__version__) >= Version("1.13.0"): send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device) else: send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index f17fad1b6..331e49729 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d +from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row @@ -18,6 +18,7 @@ __all__ = [ "DropoutForParallelInput", "DropoutForReplicatedInput", "cross_entropy_1d", + "dist_cross_entropy", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index a27fd35c1..feebd2d05 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -146,7 +146,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if use_bias: bias.view(bias.shape) - total_input = input + total_input = input.contiguous() grad_input = grad_output.matmul(weight) grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for 
execution compatibility diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 141baf3d3..5872c6485 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -139,12 +139,11 @@ class ColoAttention: # no padding assert is_causal outputs["attention_mask_type"] = AttnMaskType.CAUSAL - attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv) + attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device) + if s_q != 1: + attention_mask = attention_mask.tril(diagonal=0) + attention_mask = attention_mask.expand(b, s_q, s_kv) else: - assert q_padding_mask.shape == ( - b, - s_q, - ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})" max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) if kv_padding_mask is None: # self attention @@ -156,7 +155,7 @@ class ColoAttention: b, s_kv, ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})" - attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device) + attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) outputs.update( { "cu_seqlens_q": cu_seqlens_q, @@ -169,7 +168,8 @@ class ColoAttention: ) if is_causal: outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL - attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) + if s_q != 1: + attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) else: outputs["attention_mask_type"] = AttnMaskType.PADDED attention_mask = invert_mask(attention_mask).unsqueeze(1) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index a6d19edf5..cea2da03f 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -2,8 +2,11 @@ import torch import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup +from torch.nn import CrossEntropyLoss -__all__ = ["DistCrossEntropy", "cross_entropy_1d"] +from colossalai.shardformer.shard import ShardConfig + +__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] class DistCrossEntropy(Function): @@ -132,3 +135,43 @@ def cross_entropy_1d( dtype: torch.dtype = None, ) -> torch.Tensor: return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype) + + +def dist_cross_entropy( + labels: torch.Tensor, + logits: torch.Tensor, + shard_config: ShardConfig, + out_features: int, + vocab_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Helper to compute cross entropy loss for most shardformer models, + compatible with PP, TP and SP. 
+ """ + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + # Cross entropy with all-reduce for TP + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=out_features, + dtype=dtype, + ) + else: + # NOTE if use TP and not parallel_output, the output is gathered. + # see VocabParallelLMHead1D + shift_logits = shift_logits.view(-1, vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + return loss diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index d8425b58d..93a7eb231 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -722,6 +722,7 @@ class FusedLinear1D_Col(ParallelModule): process_group=process_group, weight=module.weight, bias_=module.bias, + n_fused=n_fused, *args, **kwargs, ) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 154143626..26ffef6c5 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -359,30 +359,14 @@ class BloomPipelineForwards: hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = lm_logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) - else: - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels.view(-1)) + loss = dist_cross_entropy( + labels, + lm_logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.transformer.dtype, + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -1040,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < 
n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - new_vocab_size = lm_logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) + if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 53c151f02..34d900d8d 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -11,7 +11,11 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.layer import AttnMaskType, ColoAttention -from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) def get_flash_core_attention_forward(): @@ -203,6 +207,13 @@ class ChatGLMPipelineForwards: dim=0, process_group=shard_config.tensor_parallel_process_group, ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=0, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -235,6 +246,13 @@ class ChatGLMPipelineForwards: dim=0, process_group=shard_config.tensor_parallel_process_group, ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=0, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=shard_config.sequence_parallel_size, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -329,7 +347,9 @@ class ChatGLMPipelineForwards: return transformer_outputs -def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): +def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, sp_size, sp_group): + logger = logging.get_logger(__name__) + def forward( self, input_ids, @@ -381,13 +401,27 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): rotary_pos_emb = rotary_pos_emb[None, :seq_length] rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + if sp_mode in ["all_to_all"] and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..." + ) + use_cache = False # Run encoder. 
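The sequence-parallel hunks in this file bracket the encoder with `split_forward_gather_backward` on entry and `gather_forward_split_backward` on exit, using reciprocal `grad_scale` factors (`1 / sp_size` and `sp_size`) so the backward pass stays consistent. Below is a minimal single-process sketch (gloo backend, world size 1, all sizes illustrative) of the forward-path data movement only; the real helpers also run the inverse collective and apply the gradient scaling in backward.

```python
# Single-process sketch (world size 1): chunk the activation along the sequence
# dimension on the way in, all-gather it back on the way out.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

sp_size = dist.get_world_size()
rank = dist.get_rank()

hidden_states = torch.randn(8, 2, 16)                 # [seq_len, batch, hidden]
shard = hidden_states.chunk(sp_size, dim=0)[rank]     # "split" on entry

gathered = [torch.empty_like(shard) for _ in range(sp_size)]
dist.all_gather(gathered, shard)                      # "gather" on exit
restored = torch.cat(gathered, dim=0)

assert torch.equal(restored, hidden_states)
dist.destroy_process_group()
```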
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] - inputs_embeds = split_forward_gather_backward( - inputs_embeds, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if sp_mode in ["split_gather"]: + inputs_embeds = split_forward_gather_backward( + inputs_embeds, + dim=0, + process_group=sp_group, + ) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward( + inputs_embeds, + dim=0, + process_group=sp_group, + grad_scale=1 / sp_size, + ) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, @@ -397,11 +431,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): output_hidden_states=output_hidden_states, ) - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if sp_mode in ["split_gather"]: + hidden_states = gather_forward_split_backward( + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, + ) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=0, + process_group=sp_group, + grad_scale=sp_size, + ) if not return_dict: return tuple( @@ -423,3 +465,158 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): ) return forward + + +def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, sp_mode, sp_size, sp_group): + from .chatglm2_6b.modeling_chatglm import apply_rotary_pos_emb, split_tensor_along_last_dim + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + mixed_x_layer = self.query_key_value(hidden_states) + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + sq, bs, _, _ = value_layer.size() + + query_layer = query_layer.reshape(sq, bs, -1) + key_layer = key_layer.reshape(sq, bs, -1) + value_layer = value_layer.reshape(sq, bs, -1) + + query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0) + key_layer = all_to_all_comm(key_layer, 
sp_group, gather_dim=0) + value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0) + + query_layer = query_layer.view( + sq * sp_size, + bs, + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ).contiguous() + + key_layer = key_layer.view( + sq * sp_size, + bs, + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ).contiguous() + + value_layer = value_layer.view( + sq * sp_size, + bs, + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ).contiguous() + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + ( + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + if sp_mode == "all_to_all": + context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0) + + # ================= + # Output. 
[sq, b, h] + # ================= + output = self.dense(context_layer) + + return output, kv_cache + + return forward diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 07a7f6cbf..5b36fc7db 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( @@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import ( ) from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class CommandPipelineForwards: @@ -117,7 +116,7 @@ class CommandPipelineForwards: # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + mask_shape = (batch_size, 1, seq_length, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, @@ -135,6 +134,21 @@ class CommandPipelineForwards: ) use_cache = False + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + ) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -191,6 +205,21 @@ class CommandPipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=shard_config.sequence_parallel_size, + ) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -300,29 +329,9 @@ class CommandPipelineForwards: logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] 
- shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -658,24 +667,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.model.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py new file mode 100644 index 000000000..a84a30972 --- /dev/null +++ b/colossalai/shardformer/modeling/deepseek.py @@ -0,0 +1,754 @@ +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.utils import is_flash_attn_2_available, logging + +from colossalai.lazy import LazyInitContext +from colossalai.moe._operation import ( + DPGradScalerIn, + DPGradScalerOut, + EPGradScalerIn, + EPGradScalerOut, + all_to_all_uneven, +) +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) +from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.shard import ShardConfig +from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group + + +# copied from modeling_deepseek.py +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. 
+ """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class EPDeepseekMoE(nn.Module): + def __init__(self): + raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") + + def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): + assert tp_group is not None + assert moe_dp_group is not None + assert ep_group is not None + + self.ep_size = dist.get_world_size(ep_group) + self.ep_rank = dist.get_rank(ep_group) + self.num_experts = self.config.n_routed_experts + assert self.num_experts % self.ep_size == 0 + + self.ep_group = ep_group + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + + set_tensors_to_none(self.experts, exclude=set(held_experts)) + + # setup moe_dp group + self.moe_dp_group = moe_dp_group + self.moe_dp_size = moe_dp_group.size() + + # setup tp group + self.tp_group = tp_group + if self.tp_group.size() > 1: + for expert in held_experts: + expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group) + expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group) + expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group) + + for p in self.experts.parameters(): + set_moe_tensor_ep_group(p, ep_group) + + @staticmethod + def from_native_module( + module, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + *args, + **kwargs, + ) -> "EPDeepseekMoE": + LazyInitContext.materialize(module) + if module.__class__.__name__ == "DeepseekMLP": + return module + module.__class__ = EPDeepseekMoE + module.setup_process_groups(tp_group, moe_dp_group, ep_group) + return module + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + identity = hidden_states + orig_shape = hidden_states.shape + + topk_experts_idx, topk_experts_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # [t0, t1, t2 ...] + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) # after repeat_interleave: [t0 t0 t1 t1 t2 t2 ... ] + + flat_topk_experts_idx = topk_experts_idx.view(-1) # [e0 e1 e2 ...] + # The elements of flat_topk_token_idx are token ids, which are arranged in ascending order of expert ids. 
+ flat_topk_token_idx = flat_topk_experts_idx.argsort() + + # Now we adjust the order of the hidden states, also in ascending order of expert id + dispatch_states = hidden_states[flat_topk_token_idx] + input_split_sizes = flat_topk_experts_idx.bincount(minlength=self.num_experts) # [n0, n1, n2, n3] + output_split_sizes = torch.zeros_like(input_split_sizes) + + # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3] + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + + with torch.no_grad(): + activate_experts = output_split_sizes[: self.num_experts_per_ep].clone() + for i in range(1, self.ep_size): + activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] + activate_experts = (activate_experts > 0).float() + dist.all_reduce(activate_experts, group=self.moe_dp_group) + + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + output_states = EPGradScalerIn.apply(output_states, self.ep_size) + + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + expert = self.experts[self.expert_start_idx] + output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0]) + output_states = expert(output_states) + output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0]) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in enumerate(output_states_splits): + if split_states.size(0) == 0: # no token routed to this experts + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = DPGradScalerIn.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) + split_states = expert(split_states) + split_states = DPGradScalerOut.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + output_states = EPGradScalerOut.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_token_idx = torch.empty_like(flat_topk_token_idx) + recover_token_idx[flat_topk_token_idx] = torch.arange( + flat_topk_token_idx.size(0), device=flat_topk_token_idx.device + ) + + output_hidden_states = dispatch_states[recover_token_idx] # t0 t0 t1 t1 t2 t2 + output_hidden_states = output_hidden_states.view(-1, self.num_experts_per_tok, orig_shape[-1]) + output_hidden_states = (output_hidden_states * topk_experts_weight[:, :, None]).sum(dim=-2) # (B*S, h) + output_hidden_states = output_hidden_states.view(*orig_shape) + output_hidden_states = AddAuxiliaryLoss.apply(output_hidden_states, aux_loss) + if self.config.n_shared_experts is not None: + output_hidden_states = output_hidden_states + self.shared_experts(identity) + return output_hidden_states + + +class DeepseekPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. 
+ """ + + @staticmethod + def deepseek_model_forward( + self: "DeepseekModel", + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if is_flash_attn_2_available(): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + } + + 
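The causal-LM forward defined next computes its loss on logits shifted by one position, and most other models in this patch now route the same computation through `dist_cross_entropy`. A toy sketch of the shift with plain `CrossEntropyLoss` (no tensor parallelism, made-up shapes):

```python
# Position t predicts token t+1, so the last logit row and the first label
# column are dropped before flattening.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len))

shift_logits = logits[..., :-1, :].contiguous()   # predictions for positions 0 .. T-2
shift_labels = labels[..., 1:].contiguous()       # targets are the next tokens

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
loss.backward()
print(loss.item())
```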
@staticmethod + def deepseek_for_causal_lm_forward( + self: "DeepseekForCausalLM", + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = DeepseekPipelineForwards.deepseek_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = hidden_states + return out + + +def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + # DeepseekFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0 + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + return forward + + +def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." + ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
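The `all_to_all_comm` calls in the attention forward above trade a sequence shard for a head shard before attention and trade it back afterwards. A communication-free sketch of the two layouts involved (batch-first as in this model, sizes illustrative; the ChatGLM variant uses a sequence-first layout with `gather_dim=0`):

```python
# Before the q/k/v all-to-all each rank holds a sequence shard with all heads;
# after it each rank holds the full sequence with num_heads / sp_size heads.
# The output all-to-all (scatter_dim=1, gather_dim=2) undoes this re-partition.
import torch

sp_size, bsz, seq, heads, head_dim = 2, 1, 8, 4, 16
full = torch.randn(bsz, seq, heads * head_dim)     # the logical, unsharded activation

# layout before the all-to-all: a sequence shard, all heads
seq_shards = list(full.chunk(sp_size, dim=1))
assert seq_shards[0].shape == (bsz, seq // sp_size, heads * head_dim)

# layout after the all-to-all: the full sequence, a head shard
head_shards = list(full.view(bsz, seq, heads, head_dim).chunk(sp_size, dim=2))
assert head_shards[0].shape == (bsz, seq, heads // sp_size, head_dim)
```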
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index beaa47952..0dbf0ca5a 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -372,27 +372,9 @@ class GPT2PipelineForwards: hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - shift_labels = shift_labels.view(-1) - 
if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) - else: - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1284,24 +1266,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index f83735f05..693f6584f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import ( ) from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class LlamaPipelineForwards: @@ -86,13 +86,20 @@ class LlamaPipelineForwards: device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape device = hidden_states.device + # Support SP + PP + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): + # For correct positions ids. The states will be gather along the seq dim in the attention layer later. 
+ seq_length *= sp_size + past_seen_tokens = 0 if use_cache: # kept for BC (cache positions) if not isinstance(past_key_values, StaticCache): @@ -101,7 +108,7 @@ class LlamaPipelineForwards: if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device) seq_length_with_past = seq_length + past_seen_tokens @@ -118,7 +125,6 @@ class LlamaPipelineForwards: if position_ids is None: position_ids = cache_position.unsqueeze(0) - # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: @@ -134,6 +140,13 @@ class LlamaPipelineForwards: else: attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + # Support SP + PP + if stage_manager.is_first_stage(): + if sp_mode in ["ring", "split_gather"]: + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) + if self.gradient_checkpointing and self.training and use_cache: if use_cache: logger.warning_once( @@ -196,6 +209,10 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) # add hidden states from the last decoder layer if output_hidden_states: @@ -304,29 +321,9 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -649,7 +645,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N # in this case, attention_mask is a dict rather than a tensor 
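The hunk just below narrows `mask_shape` to `(batch, 1, seq_len, past_seen_tokens + seq_len)`: with a KV cache the query block has only `seq_len` rows while the key axis also spans the cached positions. For reference, a toy sketch of the rectangular causal mask that shape describes, including the single-token decode case where the `tril` step can be skipped, as the `attn.py` change earlier in this patch does:

```python
# Query i (at absolute position past + i) may attend to keys 0 .. past + i,
# which is a tril with a diagonal offset of s_kv - s_q. When s_q == 1 the mask
# is all ones, so the tril can be skipped entirely.
import torch

past, s_q = 3, 4
s_kv = past + s_q
mask = torch.ones(s_q, s_kv).tril(diagonal=s_kv - s_q)
print(mask)

decode_mask = torch.ones(1, s_kv)                 # s_q == 1: every cached key is visible
assert torch.equal(decode_mask, torch.ones(1, s_kv).tril(diagonal=s_kv - 1))
```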
if shard_config.enable_flash_attention: - mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) + mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, inputs_embeds.dtype, @@ -814,24 +810,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 310c2d8e2..ec1a8a00a 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -19,7 +19,7 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy logger = logging.get_logger(__name__) @@ -91,7 +91,7 @@ class MistralForwards: if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length, seq_length) + mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, @@ -275,29 +275,9 @@ class MistralForwards: logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -708,23 +688,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - 
shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 2fbc34302..d30ce5ea8 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -1,52 +1,105 @@ -from typing import List, Optional +import inspect +import warnings +from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn.functional as F from torch.distributed import ProcessGroup - -# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo from torch.nn import CrossEntropyLoss -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.models.mixtral.modeling_mixtral import ( MixtralSparseMoeBlock, MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + apply_rotary_pos_emb, load_balancing_loss_func, + repeat_kv, ) from transformers.utils import is_flash_attn_2_available, logging from colossalai.lazy import LazyInitContext -from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven +from colossalai.moe._operation import ( + DPGradScalerIn, + DPGradScalerOut, + EPGradScalerIn, + EPGradScalerOut, + all_to_all_uneven, +) from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) +from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func + + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): - def __init__(self, config): - self.moe_info = None - super().__init__(config) + def __init__(self, *args, **kwargs): + raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_ep(self, ep_group: ProcessGroup): - ep_group = ep_group - self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 - self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 - assert self.num_experts % self.ep_size == 0 + def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): + assert tp_group is not None + assert moe_dp_group 
is not None + assert ep_group is not None + + # setup ep group + self.ep_size = dist.get_world_size(ep_group) + self.ep_rank = dist.get_rank(ep_group) self.ep_group = ep_group + + if self.num_experts % self.ep_size != 0: + raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") + self.num_experts_per_ep = self.num_experts // self.ep_size self.expert_start_idx = self.ep_rank * self.num_experts_per_ep held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) + + # setup moe_dp group + self.moe_dp_group = moe_dp_group + self.moe_dp_size = moe_dp_group.size() + + # setup global tp group + self.tp_group = tp_group + if self.tp_group.size() > 1: + for expert in held_experts: + expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group) + expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group) + expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group) + for p in self.experts.parameters(): - p.ep_group = ep_group + set_moe_tensor_ep_group(p, ep_group) @staticmethod - def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": + def from_native_module( + module: MixtralSparseMoeBlock, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + *args, + **kwargs, + ) -> "EPMixtralSparseMoeBlock": + # TODO: better init LazyInitContext.materialize(module) module.__class__ = EPMixtralSparseMoeBlock - # if "ep_group" in kwargs: - assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!" - module.setup_ep(kwargs["ep_group"]) + module.setup_process_groups(tp_group, moe_dp_group, ep_group) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -65,20 +118,31 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): selected_experts_idx = selected_experts.argsort() dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + output_split_sizes = torch.zeros_like(input_split_sizes) dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + with torch.no_grad(): + activate_experts = output_split_sizes[: self.num_experts_per_ep].clone() + for i in range(1, self.ep_size): + activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] + activate_experts = (activate_experts > 0).float() + dist.all_reduce(activate_experts, group=self.moe_dp_group) + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) # compute expert output - output_states = MoeInGradScaler.apply(output_states, self.ep_size) + output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: if self.num_experts_per_ep == 1: # no need to split expert = self.experts[self.expert_start_idx] + output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0]) output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) output_states = expert.w2(output_states) + output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, 
activate_experts[0]) else: output_states_splits = output_states.split(output_split_sizes.tolist()) output_states_list = [] @@ -86,12 +150,20 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): if split_states.size(0) == 0: continue expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = DPGradScalerIn.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) split_states = expert.w2(split_states) + split_states = DPGradScalerOut.apply( + split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep] + ) output_states_list.append(split_states) output_states = torch.cat(output_states_list) - output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + + output_states = EPGradScalerOut.apply(output_states, self.ep_size) dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_experts_idx = torch.empty_like(selected_experts_idx) recover_experts_idx[selected_experts_idx] = torch.arange( selected_experts_idx.size(0), device=selected_experts_idx.device @@ -107,7 +179,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): class MixtralPipelineForwards: """ - This class serves as a micro library for forward function substitution of Llama models + This class serves as a micro library for forward function substitution of Mixtral models under pipeline setting. """ @@ -300,16 +372,29 @@ class MixtralPipelineForwards: if output_router_logits and past_router_logits is not None: all_router_logits = past_router_logits + all_router_logits if stage_manager.is_last_stage(): - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, ) - # always return dict for imediate stage - return { - "hidden_states": hidden_states, - "past_router_logits": all_router_logits, - } + else: + if output_router_logits: + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } + else: + return { + "hidden_states": hidden_states, + } @staticmethod def mixtral_for_causal_lm_forward( @@ -441,3 +526,335 @@ class MixtralPipelineForwards: if output_router_logits: out["past_router_logits"] = outputs["past_router_logits"] return out + + +def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + + +def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 
1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index b250b4976..636b46cc4 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -330,30 +330,14 @@ class OPTPipelineForwards: ) if stage_manager.is_last_stage(): logits = self.lm_head(outputs[0]).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.decoder.dtype, - ) - else: - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.model.decoder.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] @@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) logits = self.lm_head(outputs[0]).contiguous() - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.decoder.dtype, - ) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff 
--git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 11c26822f..538e96c32 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -1,6 +1,8 @@ +import math from typing import List, Optional, Tuple, Union import torch +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -30,9 +32,14 @@ except ImportError: from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class Qwen2PipelineForwards: @@ -129,7 +136,7 @@ class Qwen2PipelineForwards: # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + mask_shape = (batch_size, 1, seq_length, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, @@ -162,6 +169,21 @@ class Qwen2PipelineForwards: sliding_window=self.config.sliding_window, ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + ) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -218,6 +240,20 @@ class Qwen2PipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=shard_config.sequence_parallel_size, + ) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -317,25 +353,9 @@ class Qwen2PipelineForwards: if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - 
loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -469,7 +489,7 @@ class Qwen2PipelineForwards: return {"hidden_states": hidden_states} -def get_qwen2_flash_attention_forward(shard_config: ShardConfig): +def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self: Qwen2Attention, hidden_states: torch.Tensor, @@ -480,11 +500,26 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig): use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -538,10 +573,41 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig): key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." - attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + if shard_config.enable_flash_attention: + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." 
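The `all_to_all_comm` calls added above turn activations that are sharded along the sequence dimension into activations that carry the full sequence but only a slice of the heads, which is why `bsz, q_len, _` is re-read from `query_states` afterwards. Below is a single-process simulation of that exchange; the scatter/gather dimensions are assumed from the shapes implied by the patch, and no process group is involved.

```python
import torch

def simulate_all_to_all(per_rank_inputs, scatter_dim=2, gather_dim=1):
    """Simulate an all-to-all on a list of per-rank tensors: each source rank
    sends its dst-th slice of scatter_dim to rank dst, which concatenates what
    it receives along gather_dim."""
    sp_size = len(per_rank_inputs)
    outputs = []
    for dst in range(sp_size):
        pieces = [src.chunk(sp_size, dim=scatter_dim)[dst] for src in per_rank_inputs]
        outputs.append(torch.cat(pieces, dim=gather_dim))
    return outputs

B, S, H, sp = 2, 16, 64, 4
seq_sharded = [torch.randn(B, S // sp, H) for _ in range(sp)]  # before attention: partial seq, full hidden
head_sharded = simulate_all_to_all(seq_sharded)                # after: full seq, partial hidden (heads)
assert head_sharded[0].shape == (B, S, H // sp)
```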
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -549,9 +615,8 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig): return forward -def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig): +def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) - assert shard_config.enable_flash_attention, "Flash Attention is not enabled." 
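When flash attention is disabled, the patch falls back to the eager attention path shown above: a scaled dot-product, an additive mask, a softmax upcast to fp32 for stability, and a weighted sum over the values. A minimal standalone version of that fallback, with an illustrative causal mask, might look like this (float32 inputs are used so the sketch runs anywhere; the fp32 upcast matters when the model runs in fp16/bf16):

```python
import math
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, attn_mask=None):
    """q, k, v: (bsz, num_heads, seq_len, head_dim); attn_mask is additive and
    broadcastable to (bsz, num_heads, q_len, kv_len) with 0 / -inf entries."""
    head_dim = q.size(-1)
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    # upcast the softmax to fp32, then cast back to the query dtype
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(attn_weights, v)  # (bsz, num_heads, q_len, head_dim)

bsz, heads, seq, dim = 1, 2, 4, 8
q = k = v = torch.randn(bsz, heads, seq, dim)
causal_mask = torch.full((seq, seq), float("-inf")).triu(1)  # mask future positions
print(eager_attention(q, k, v, causal_mask).shape)           # torch.Size([1, 2, 4, 8])
```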
def forward( self, @@ -586,6 +651,10 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig): seq_length_with_past = seq_length past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -601,17 +670,26 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig): # embed positions hidden_states = inputs_embeds - # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) - if self.gradient_checkpointing and self.training: + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -623,6 +701,11 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig): all_self_attns = () if output_attentions else None next_decoder_cache = None + if sp_mode in ["ring", "split_gather"]: + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -657,6 +740,11 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig): hidden_states = self.norm(hidden_states) + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -737,26 +825,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = 
shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index bf139c840..7b9c759a6 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -160,6 +160,13 @@ _POLICY_LIST = { "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" ), + # Deepseek + "transformers_modules.modeling_deepseek.DeepseekModel": PolicyLocation( + file_name="deepseek", class_name="DeepseekModelPolicy" + ), + "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation( + file_name="deepseek", class_name="DeepseekForCausalLMPolicy" + ), # Falcon "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation( file_name="falcon", class_name="FalconModelPolicy" @@ -193,6 +200,9 @@ _POLICY_LIST = { "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation( file_name="mixtral", class_name="MixtralForCausalLMPolicy" ), + "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation( + file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy" + ), # Qwen2 "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation( file_name="qwen2", class_name="Qwen2ModelPolicy" @@ -233,6 +243,9 @@ def _fullname(obj): # patch custom models which are not in transformers # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub) # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory) + if module.startswith("peft"): + klass = obj.base_model.model.__class__ + module = klass.__module__ if module.startswith("transformers_modules"): split_module = module.split(".") if len(split_module) >= 2: @@ -252,7 +265,6 @@ def get_autopolicy(model: nn.Module) -> Policy: """ full_name = _fullname(model) policy_location = _POLICY_LIST.get(full_name, None) - if policy_location is None: raise NotImplementedError( f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. 
Supported models are {list(_POLICY_LIST.keys())}" diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 01aa77e57..3877bdac3 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -9,6 +9,7 @@ import colossalai.shardformer.layer as col_nn from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards from ..modeling.chatglm2 import ( + get_chatglm_sequence_parallel_attention_forward, get_chatglm_sequence_parallel_forward_fn, get_flash_core_attention_forward, get_jit_fused_glm_block_forward, @@ -58,14 +59,29 @@ class ChatGLMPolicy(Policy): norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode or None - assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2" + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + if sp_mode == "ring": warnings.warn( f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" ) sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode == "split_gather" + sp_partial_derived = sp_mode in ["split_gather"] + + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + "hidden_size_per_partition": self.model.config.kv_channels + * self.model.config.num_attention_heads + // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + policy["CoreAttention"] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) if self.shard_config.enable_tensor_parallelism: assert ( @@ -179,12 +195,26 @@ class ChatGLMPolicy(Policy): ) # use sequence parallel - if sp_mode == "split_gather": + if self.shard_config.enable_sequence_parallelism: self.append_or_create_method_replacement( - description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + description={ + "forward": get_chatglm_sequence_parallel_attention_forward( + self.shard_config, sp_mode, sp_size, sp_group + ), + }, policy=policy, - target_key="ChatGLMModel", + target_key="SelfAttention", ) + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_chatglm_sequence_parallel_forward_fn( + self.shard_config, sp_mode, sp_size, sp_group + ) + }, + policy=policy, + target_key="ChatGLMModel", + ) # use jit fused operator if self.shard_config.enable_jit_fused: diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 902baf2e1..a9b915d10 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -66,13 +65,6 @@ class CommandPolicy(Policy): else: norm_cls = LayerNorm - if self.pipeline_stage_manager is not None: - self.shard_config.enable_sequence_parallelism = False - self.shard_config.enable_sequence_overlap = False - self.shard_config.sequence_parallelism_mode = None - warnings.warn( - f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" - ) sp_mode = self.shard_config.sequence_parallelism_mode or None 
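The `auto_policy.py` hunk above registers the Deepseek and Mixtral sequence-classification policies and teaches `_fullname` to unwrap PEFT models before the lookup. A rough sketch of that resolution flow, with hypothetical registry entries and function names (not the library API), is:

```python
import torch.nn as nn

_POLICY_REGISTRY = {
    "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": ("mixtral", "MixtralForCausalLMPolicy"),
    # custom-code models downloaded from the hub are normalized to this short form
    "transformers_modules.modeling_deepseek.DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLMPolicy"),
}

def resolve_policy_location(model: nn.Module):
    klass = model.__class__
    module = klass.__module__
    if module.startswith("peft"):
        # PEFT wraps the base model; look up the policy of the wrapped model instead
        klass = model.base_model.model.__class__
        module = klass.__module__
    if module.startswith("transformers_modules"):
        # keep only the leading package and trailing module file, dropping repo/hash parts
        parts = module.split(".")
        module = f"{parts[0]}.{parts[-1]}"
    full_name = f"{module}.{klass.__qualname__}"
    location = _POLICY_REGISTRY.get(full_name)
    if location is None:
        raise NotImplementedError(f"Auto policy for {full_name} is not implemented")
    return location
```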
sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py new file mode 100644 index 000000000..605f69c4a --- /dev/null +++ b/colossalai/shardformer/policies/deepseek.py @@ -0,0 +1,347 @@ +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module +from transformers.utils import is_flash_attn_greater_or_equal_2_10 + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D +from colossalai.shardformer.layer.linear import Linear1D_Row +from colossalai.shardformer.modeling.deepseek import ( + DeepseekPipelineForwards, + EPDeepseekMoE, + get_deepseek_flash_attention_forward, + get_deepseek_flash_attention_model_forward, +) +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"] + + +class DeepseekPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation + """ + Because transformers library's bug for AutoModel/AutoConfig, who pop “attn_implement” twice from modeling_utils.py and configuration_utils.py. + This bug causes attn_cls to be set to sdpa. Here we assign it to "flash_attention_2". + """ + # self.origin_attn_implement = "flash_attention_2" + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + + ATTN_IMPLEMENTATION = { + "eager": "DeepseekAttention", + "flash_attention_2": "DeepseekFlashAttention2", + "sdpa": "DeepseekSdpaAttention", + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + if self.shard_config.enable_sequence_parallelism: + if self.pipeline_stage_manager is not None: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.") + self.append_or_create_method_replacement( + description={ + "forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + 
target_key=attn_cls, + ) + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_deepseek_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key="DeepseekModel", + ) + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_tensor_parallelism: + # tensor parallelism for non-moe params + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + assert ( + self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } + + policy["DeepseekDecoderLayer"] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + ], + ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key="DeepseekModel", + ) + + if self.shard_config.ep_group: + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="mlp", + target_module=EPDeepseekMoE, + kwargs={ + "ep_group": self.shard_config.ep_group, + "tp_group": self.shard_config.tensor_parallel_process_group, + "moe_dp_group": self.shard_config.moe_dp_group, + }, + ) + ], + policy=policy, + target_key="DeepseekDecoderLayer", + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + ], + policy=policy, + target_key="DeepseekDecoderLayer", + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + policy=policy, + target_key="DeepseekModel", + ) + + if self.shard_config.enable_flash_attention: + # NOTE: there is a bug for toggling flash attention in AutoModel, which has to be used for 
deepseek right now + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + flash_attn_cls = get_class_from_dynamic_module( + "deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2", + "deepseek-ai/deepseek-moe-16b-base", + ) + + class TargetFlashAttn: + def __init__(self): + raise RuntimeError("This class should not be instantiated") + + @staticmethod + def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module: + original_attn.__class__ = flash_attn_cls + original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + return original_attn + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="self_attn", + target_module=TargetFlashAttn, + ), + policy=policy, + target_key="DeepseekDecoderLayer", + ) + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + if self.shard_config.enable_sequence_parallelism: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "DeepseekModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "DeepseekModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class DeepseekModelPolicy(DeepseekPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls="DeepseekModel", + new_forward=DeepseekPipelineForwards.deepseek_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class DeepseekForCausalLMPolicy(DeepseekPolicy): + def module_policy(self): + policy = super().module_policy() + # TODO: assign pg mesh from plugin to 
all modules + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + "DeepseekForCausalLM": ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls="DeepseekForCausalLM", + new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + deepseek_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: deepseek_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 0f28f6cf4..6f8404219 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -65,13 +64,7 @@ class LlamaPolicy(Policy): norm_cls = FusedRMSNorm else: norm_cls = RMSNorm - if self.pipeline_stage_manager is not None: - self.shard_config.enable_sequence_parallelism = False - self.shard_config.enable_sequence_overlap = False - self.shard_config.sequence_parallelism_mode = None - warnings.warn( - f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" - ) + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 0fb858d78..10df143c9 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -1,13 +1,21 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union import torch.nn as nn from torch import Tensor from torch.nn import Module -from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col -from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards +from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D +from colossalai.shardformer.layer.linear import Linear1D_Row +from colossalai.shardformer.modeling.mixtral import ( + EPMixtralSparseMoeBlock, + MixtralPipelineForwards, + get_mixtral_flash_attention_forward, + get_mixtral_flash_attention_model_forward, +) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, 
SubModuleReplacementDescription __all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] @@ -18,36 +26,136 @@ class MixtralPolicy(Policy): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - policy = {} + from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralFlashAttention2, + MixtralModel, + MixtralSdpaAttention, + ) + ATTN_IMPLEMENTATION = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, + "sdpa": MixtralSdpaAttention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) if self.shard_config.enable_sequence_parallelism: - self.shard_config.enable_sequence_parallelism = False - raise NotImplementedError( - "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + if self.pipeline_stage_manager is not None: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.") + self.append_or_create_method_replacement( + description={ + "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_mixtral_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=MixtralModel, ) + embedding_cls = None if self.shard_config.enable_tensor_parallelism: - raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") - if getattr(self.shard_config, "ep_group", None) is not None: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if self.shard_config.enable_tensor_parallelism: + # tensor parallelism for non-moe params + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + assert ( + self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." 
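# Standalone illustrative sketch of the per-rank arithmetic that the two asserts above
# guard, mirroring the decoder_attribute_replacement dict built just below. The helper
# name and the example numbers are hypothetical; only the divide-by-tp_size rule comes
# from the policy code itself.
def shard_attention_attrs(hidden_size: int, num_heads: int, num_kv_heads: int, tp_size: int) -> dict:
    # Each tensor-parallel rank keeps 1/tp_size of the query and key/value heads, so both
    # counts must divide evenly or the Linear1D_Col / Linear1D_Row splits would misalign.
    assert num_heads % tp_size == 0 and num_kv_heads % tp_size == 0
    return {
        "self_attn.hidden_size": hidden_size // tp_size,
        "self_attn.num_heads": num_heads // tp_size,
        "self_attn.num_key_value_heads": num_kv_heads // tp_size,
    }

# Mixtral-8x7B-like numbers (hidden 4096, 32 query heads, 8 KV heads) with tp_size=4 give:
# {'self_attn.hidden_size': 1024, 'self_attn.num_heads': 8, 'self_attn.num_key_value_heads': 2}
print(shard_attention_attrs(hidden_size=4096, num_heads=32, num_kv_heads=8, tp_size=4))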
+ decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } + + policy[MixtralDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( # or replicate? + suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True} + ), + ], + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=MixtralModel, + ) + + if self.shard_config.ep_group: # expert parallel self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( suffix="block_sparse_moe", target_module=EPMixtralSparseMoeBlock, - kwargs={"ep_group": self.shard_config.ep_group}, + kwargs={ + "ep_group": self.shard_config.ep_group, + "tp_group": self.shard_config.tensor_parallel_process_group, + "moe_dp_group": self.shard_config.moe_dp_group, + }, ) ], policy=policy, @@ -61,10 +169,12 @@ class MixtralPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -75,13 +185,15 @@ class MixtralPolicy(Policy): description=SubModuleReplacementDescription( suffix="norm", target_module=FusedRMSNorm, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key=MixtralModel, ) if self.shard_config.enable_flash_attention: - raise NotImplementedError("Flash attention has already been replaced in mixtral.") + warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.") + self.shard_config.enable_flash_attention = False return policy @@ -92,6 +204,10 @@ class MixtralPolicy(Policy): """If under pipeline parallel setting, replacing the original forward method of huggingface to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: + if self.shard_config.enable_sequence_parallelism: + # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism + # if both are enabled, one of them will be ignored + raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.") stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "MixtralModel": module = self.model @@ -150,7 +266,7 @@ class MixtralModelPolicy(MixtralPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama 
model""" + """No shared params in mixtral model""" return [] @@ -192,17 +308,54 @@ class MixtralForCausalLMPolicy(MixtralPolicy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - llama_model = self.model.model + mixtral_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if ( - id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1 ): # tie weights return [ { - 0: llama_model.embed_tokens.weight, + 0: mixtral_model.embed_tokens.weight, self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, } ] return [] + + +class MixtralForSequenceClassificationPolicy(MixtralPolicy): + def module_policy(self): + from transformers import MixtralForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + MixtralForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + raise NotImplementedError + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in mixtral for sequence classification model""" + return [] diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 3e427c4a1..362c14060 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -82,9 +81,20 @@ class Qwen2Policy(Policy): embedding_cls = PaddingEmbedding norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm - if self.shard_config.enable_sequence_parallelism: - self.shard_config.enable_sequence_parallelism = False - warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) if self.shard_config.enable_tensor_parallelism: assert ( @@ -109,30 +119,37 @@ class Qwen2Policy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + 
kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), ], ) @@ -154,10 +171,12 @@ class Qwen2Policy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -168,16 +187,16 @@ class Qwen2Policy(Policy): description=SubModuleReplacementDescription( suffix="norm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key=Qwen2Model, ) - # use flash attention - if self.shard_config.enable_flash_attention: + if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: self.append_or_create_method_replacement( description={ - "forward": get_qwen2_flash_attention_forward(self.shard_config), + "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, target_key=attn_cls, @@ -186,7 +205,9 @@ class Qwen2Policy(Policy): # replace qwen2 model forward method self.append_or_create_method_replacement( description={ - "forward": get_qwen2_model_forward_for_flash_attn(self.shard_config), + "forward": get_qwen2_model_forward_for_flash_attn( + self.shard_config, sp_mode, sp_size, sp_group + ), }, policy=policy, target_key=Qwen2Model, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 7372e06c2..8e33d786b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -47,6 +47,9 @@ class ShardConfig: make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + + # for moe related + moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None fp8_communication: bool = False # pipeline_parallel_size: int diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index b54c58273..db03eec41 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,4 +1,3 @@ -import os from typing import Dict, List, Tuple import torch.distributed as dist @@ -11,9 +10,6 @@ from ..policies.base_policy import Policy from .shard_config import ShardConfig from .sharder import ModelSharder -# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct -os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - class ShardFormer: """ diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index c2cf73181..0f0150d90 100644 --- 
a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -473,7 +473,7 @@ class LayoutConverter(metaclass=SingletonMeta): for process_group in used_process_groups: try: dist.get_rank(process_group) - except RuntimeError as e: + except (ValueError, RuntimeError) as e: # If the group is not registered, it means it has been deleted if str(e) == ( f"Group {process_group} is not registered, please create group with torch.distributed.new_group API" diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 16a4f248b..76d85a112 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Dict, List from ..utils import merge_same_dim_mesh_list @@ -23,10 +22,11 @@ class DimSpec: Otherwise, the element in shard_list means the data will be sharded in that dimension. """ + _DIFFERENCE_DICT = None + def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 self.shard_list = shard_list - self.build_difference_2d_dict() def __eq__(self, other): return str(self) == str(other) @@ -39,24 +39,43 @@ class DimSpec: target += str(dim) return target - def _convert_str_to_shard_list(self, str_spec): + @property + def difference_dict(self): """ - Convert str_spec into shard_list. + Returns the difference dict, and lazily initializes it when needed + + Return: + difference_dict(Dict[Tuple[int, int], Union[int, float, str]]): + difference dict + """ + if self._DIFFERENCE_DICT is None: + self._DIFFERENCE_DICT = self._build_difference_2d_dict() + + return self._DIFFERENCE_DICT + + def dim_diff(self, other): + """ + The difference between two DimSpec. Argument: - str_spec(str): dim spec in str type. + other(DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two DimSpec. + + Example: + dim_spec = DimSpec([0]) + other_dim_spec = DimSpec([0, 1]) + print(dim_spec.dim_diff(other_dim_spec)) + + Output: + 5 """ + difference = self.difference_dict[(str(self), str(other))] + return difference - if str_spec == "R": - return [] - if str_spec == "S0": - return [0] - if str_spec == "S1": - return [1] - if str_spec == "S01": - return [0, 1] - - def build_difference_2d_dict(self): + @classmethod + def _build_difference_2d_dict(cls): """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. @@ -67,9 +86,8 @@ class DimSpec: difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) - source_shard_list = self._convert_str_to_shard_list(source_spec) - target_shard_list = self._convert_str_to_shard_list(target_spec) + source_shard_list = cls._convert_str_to_shard_list(source_spec) + target_shard_list = cls._convert_str_to_shard_list(target_spec) # source same as target if source_shard_list == target_shard_list: @@ -112,30 +130,27 @@ class DimSpec: else: difference = NAN - difference_dict[spec_pair] = difference + difference_dict[(source_spec, target_spec)] = difference - self.difference_dict = difference_dict + return difference_dict - def dim_diff(self, other): + @staticmethod + def _convert_str_to_shard_list(str_spec): """ - The difference between two _DimSpec. + Convert str_spec into shard_list. Argument: - other(_DimSpec): the dim spec to compare with. - - Return: - difference(int): the difference between two _DimSpec. 
- - Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) - print(dim_spec.difference(other_dim_spec)) - - Output: - 5 + str_spec(str): dim spec in str type. """ - difference = self.difference_dict[(str(self), str(other))] - return difference + + if str_spec == "R": + return [] + if str_spec == "S0": + return [0] + if str_spec == "S1": + return [1] + if str_spec == "S01": + return [0, 1] class ShardingSpec: diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index b78ef6d97..fb42afab7 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -1,5 +1,4 @@ import operator -from copy import deepcopy from functools import reduce import torch @@ -27,10 +26,11 @@ class _DimSpec: Otherwise, the element in shard_list means the data will be sharded in that dimension. """ + _DIFFERENCE_DICT = None + def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 self.shard_list = shard_list - self.build_difference_2d_dict() def __eq__(self, other): return str(self) == str(other) @@ -43,27 +43,46 @@ class _DimSpec: target += str(dim) return target - def _convert_str_to_shard_list(self, str_spec): + @property + def difference_dict(self): """ - Convert str_spec into shard_list. + Returns the difference dict, and lazily initializes it when needed + + Return: + difference_dict(Dict[Tuple[int, int], Union[int, float, str]]): + difference dict + """ + if self._DIFFERENCE_DICT is None: + self._DIFFERENCE_DICT = self._build_difference_2d_dict() + + return self._DIFFERENCE_DICT + + def difference(self, other): + """ + The difference between two _DimSpec. Argument: - str_spec(str): dim spec in str type. + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 """ + difference = self.difference_dict[(str(self), str(other))] + return difference - if str_spec == "R": - return [] - if str_spec == "S0": - return [0] - if str_spec == "S1": - return [1] - if str_spec == "S01": - return [0, 1] - - def build_difference_2d_dict(self): + @classmethod + def _build_difference_2d_dict(cls): """ Build a difference mapping for 2D device mesh case. It will be used to - compute the difference between DimSpec pairs. + compute the difference between _DimSpec pairs. """ source_spec_list = ["R", "S0", "S1", "S01"] @@ -71,9 +90,8 @@ class _DimSpec: difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) - source_shard_list = self._convert_str_to_shard_list(source_spec) - target_shard_list = self._convert_str_to_shard_list(target_spec) + source_shard_list = cls._convert_str_to_shard_list(source_spec) + target_shard_list = cls._convert_str_to_shard_list(target_spec) # source same as target if source_shard_list == target_shard_list: @@ -116,30 +134,27 @@ class _DimSpec: else: difference = NAN - difference_dict[spec_pair] = difference + difference_dict[(source_spec, target_spec)] = difference - self.difference_dict = difference_dict + return difference_dict - def difference(self, other): + @staticmethod + def _convert_str_to_shard_list(str_spec): """ - The difference between two _DimSpec. + Convert str_spec into shard_list. Argument: - other(_DimSpec): the dim spec to compare with. - - Return: - difference(int): the difference between two _DimSpec. 
- - Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) - print(dim_spec.difference(other_dim_spec)) - - Output: - 5 + str_spec(str): dim spec in str type. """ - difference = self.difference_dict[(str(self), str(other))] - return difference + + if str_spec == "R": + return [] + if str_spec == "S0": + return [0] + if str_spec == "S1": + return [1] + if str_spec == "S01": + return [0, 1] class ShardingSpecException(Exception): diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index e24a67f9d..8b6d403f1 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -19,7 +19,6 @@ class GradientStore(BaseStore): """ self._grads_of_params = dict() # stage 2 - self._partition_grads = partition_grad self._working_index = 0 if partition_grad else self._local_rank # for zero2, it's `param_id: [grad_local_rank]` self.grad_to_param_mapping = dict() @@ -91,7 +90,7 @@ class GradientStore(BaseStore): return grad_list - def get_working_grad_by_param_id(self, param_id) -> Tensor: + def get_working_grad_by_param_id(self, param_id) -> Optional[Tensor]: """ Return the working gradient for the specified parameter. @@ -112,6 +111,7 @@ class GradientStore(BaseStore): def reset_all_gradients(self): self._grads_of_params = dict() + self.grad_to_param_mapping = dict() def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]: """Return the id of a parameter which the gradient slice belongs to diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 1353071c5..458e6e41a 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -21,9 +21,11 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import ( from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8 +from colossalai.tensor.moe_tensor.api import is_moe_tensor from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, TensorBucket +from .zero_hook import set_all_gather_handle, wait_all_gather_handle class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): @@ -66,7 +68,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def __init__( self, optimizer: Optimizer, - pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None, + pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -84,6 +86,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): dp_process_group: Optional[ProcessGroup] = None, forced_dtype: Optional[torch.dtype] = None, master_weights: bool = True, # master weights + overlap_allgather: bool = False, fp8_communication: bool = False, ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) @@ -92,7 +95,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._logger = get_dist_logger() self._verbose = verbose - if dp_process_group is not None and pg_to_param_list is not None: + if (dp_process_group is not None) and (pg_to_param_list is not None): raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.") if pg_to_param_list is None: @@ -123,6 +126,7 @@ class 
LowLevelZeroOptimizer(OptimizerWrapper): # communication params self._overlap_communication = overlap_communication + self._overlap_allgather = overlap_allgather self._reduce_bucket_size = reduce_bucket_size self._communication_dtype = communication_dtype self._fp8_communication = fp8_communication @@ -148,6 +152,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # record the padding size of each param self._padding_map = dict() + # padded working param is all-gather buffer and it shares the same memory with working param + self._working_param_to_padded_working_param = dict() # mapping working param and master param self.master_to_working_param = dict() @@ -248,11 +254,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): with torch.no_grad(): if padding_size > 0: padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) - # reset working params' ptr when no master weights - if self._master_weights == False: - param.data = padding_param[: param.numel()].view(param.shape) + # # reset working params' ptr when no master weights + # if self._master_weights == False: + param.data = padding_param[: param.numel()].view(param.shape) else: padding_param = param.data.view(-1) + self._working_param_to_padded_working_param[param] = padding_param splited_params = padding_param.split( padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size @@ -261,7 +268,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # use fp32 when master_weights is True if self._master_weights is True: - splited_param_current_rank = splited_params.detach().float().to(device) + splited_param_current_rank = splited_params.detach().clone().float().to(device) else: splited_param_current_rank = splited_params @@ -338,21 +345,21 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id) else: flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) - recieved_grad = torch.zeros_like(flat_grads_list[0]) + received_grad = torch.zeros_like(flat_grads_list[0]) if self._fp8_communication: reduce_scatter_fp8( - recieved_grad, + received_grad, flat_grads_list, group=bucket_store.torch_pg, ) else: - dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg) + dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) - if recieved_grad.dtype != grad_dtype: - recieved_grad = recieved_grad.to(grad_dtype) + if received_grad.dtype != grad_dtype: + received_grad = received_grad.to(grad_dtype) grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank] - self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1) + self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, received_grad, group_id, 1) bucket_store.reset() @@ -562,25 +569,29 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = real_working_params[group_id][idx] param_to_gather = master_param.to(device).to(self._dtype) pg = self.param_to_pg[working_param] - if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: - buffer_tensor = torch.empty_like( - torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))]) - ) - if self._fp8_communication: - all_gather_into_tensor_flat_fp8(buffer_tensor, param_to_gather, pg, fp8_format="e4m3") - else: - dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg) - working_param.data.copy_(buffer_tensor[: 
working_param.numel()].reshape_as(working_param)) - continue - try: - self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) - except RuntimeError: - self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication) - self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) + padded_working_param = self._working_param_to_padded_working_param[working_param] + if self._overlap_allgather: + handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) + set_all_gather_handle(working_param, handle) + else: + if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: + if self._fp8_communication: + all_gather_into_tensor_flat_fp8( + padded_working_param, param_to_gather, pg, fp8_format="e4m3" + ) + else: + dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + continue + try: + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) + except RuntimeError: + self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication) + self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] - for pg, tensor_bucket in self.pg_to_tensor_bucket.items(): - if not tensor_bucket.is_empty(): - tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication) + if not self._overlap_allgather: + for pg, tensor_bucket in self.pg_to_tensor_bucket.items(): + if not tensor_bucket.is_empty(): + tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication) def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: r""" @@ -657,6 +668,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for group_id in range(self.num_param_groups): param_group = self._working_param_groups[group_id] for param in param_group: + if is_moe_tensor(param) and param.requires_grad and param.grad is None: + # TODO better of of doing this + # assign zero grad to unrouted expert to avoid hang during grad reduction + param.grad = torch.zeros_like(param) + if param.requires_grad and param.grad is not None: self._add_to_bucket(param, group_id) @@ -815,8 +831,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): """ for p in model.parameters(): p_id = id(p) - pg = self.param_to_pg[p] if p_id in self.working_to_master_param: + pg = self.param_to_pg[p] master_param = self.working_to_master_param[p_id] padding_size = self.get_param_padding_size(p) working_param = p.data.view(-1) @@ -877,13 +893,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def get_param_grad(self, working_param: nn.Parameter) -> Tensor: grad_store = self.pid_to_grad_store[id(working_param)] - partial_grad = grad_store.get_working_grad_by_param_id(id(working_param)) - if partial_grad is None: + grad = grad_store.get_working_grad_by_param_id(id(working_param)) + if grad is None: return None - tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)] - dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg) - grad_flat = torch.cat(tensor_list, dim=0) - return grad_flat[: working_param.numel()].reshape_as(working_param) + grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device) + dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg) + return grad_flat.view(-1)[: 
working_param.numel()].view_as(working_param)

     def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
         working_grads = []
@@ -908,3 +923,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
         grad_store = self.pid_to_grad_store[param_id]
         return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
+
+    def _force_wait_all_gather(self):
+        for param in self._working_param_to_padded_working_param.keys():
+            wait_all_gather_handle(param)
diff --git a/colossalai/zero/low_level/zero_hook.py b/colossalai/zero/low_level/zero_hook.py
new file mode 100644
index 000000000..20f9ef31a
--- /dev/null
+++ b/colossalai/zero/low_level/zero_hook.py
@@ -0,0 +1,33 @@
+from typing import List
+
+from torch._tensor import Tensor
+
+from colossalai.tensor.param_op_hook import ColoParamOpHook
+
+_ALL_GATHER_HANDLE = "_all_gather_handle"
+
+
+def wait_all_gather_handle(p):
+    if hasattr(p, _ALL_GATHER_HANDLE):
+        handle = getattr(p, _ALL_GATHER_HANDLE)
+        handle.wait()
+        delattr(p, _ALL_GATHER_HANDLE)
+
+
+def set_all_gather_handle(p, handle):
+    setattr(p, _ALL_GATHER_HANDLE, handle)
+
+
+class ZeroOpHook(ColoParamOpHook):
+    def pre_forward(self, params: List[Tensor]) -> None:
+        for p in params:
+            wait_all_gather_handle(p)
+
+    def post_forward(self, params: List[Tensor]) -> None:
+        pass
+
+    def pre_backward(self, params: List[Tensor]) -> None:
+        pass
+
+    def post_backward(self, params: List[Tensor]) -> None:
+        pass
diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md
new file mode 100644
index 000000000..c11b98043
--- /dev/null
+++ b/examples/inference/stable_diffusion/README.md
@@ -0,0 +1,22 @@
+## File Structure
+```
+|- sd3_generation.py: an example of how to use the Colossal-AI Inference Engine to generate images with a diffusion model
+|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion
+|- benchmark_sd3.py: benchmark the performance of our InferenceEngine
+|- run_benchmark.sh: run the benchmark commands
+```
+Note: `compute_metric.py` has extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is located in `examples/inference/stable_diffusion/`).
+
+## Run Inference
+
+The provided `sd3_generation.py` shows how to configure and initialize the inference engine and run inference on a given model. We've added `DiffusionPipeline` as a supported model class, so the script can run inference with Stable Diffusion 3.
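In code, the workflow boils down to a few engine calls. The sketch below condenses `sd3_generation.py` from this patch into a minimal example; the model path, dtype, and output filename are placeholders, and only a subset of the available `InferenceConfig` options is shown.

```python
import colossalai
from diffusers import DiffusionPipeline
from torch import distributed as dist

from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()

# Load any DiffusionPipeline checkpoint, e.g. "stabilityai/stable-diffusion-3-medium-diffusers".
model = DiffusionPipeline.from_pretrained("PATH_MODEL")

# patched_parallelism_size controls how many GPUs cooperate on each image (Patched Parallelism).
inference_config = InferenceConfig(
    dtype="fp16",
    patched_parallelism_size=dist.get_world_size(),
)
engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

# Generate one image and save it on rank 0 only.
image = engine.generate(prompts=["hello world"], generation_config=DiffusionGenerationConfig())[0]
if dist.get_rank() == 0:
    image.save("output.jpg")
```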
+ +For a basic setting, you could run the example by: +```bash +colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world" +``` + +Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs: +```bash +colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL +``` diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py new file mode 100644 index 000000000..19db57c33 --- /dev/null +++ b/examples/inference/stable_diffusion/benchmark_sd3.py @@ -0,0 +1,179 @@ +import argparse +import json +import time +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from diffusers import DiffusionPipeline + +import colossalai +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +GIGABYTE = 1024**3 +MEGABYTE = 1024 * 1024 + +_DTYPE_MAPPING = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, +} + + +def log_generation_time(log_data, log_file): + with open(log_file, "a") as f: + json.dump(log_data, f, indent=2) + f.write("\n") + + +def warmup(engine, args): + for _ in range(args.n_warm_up_steps): + engine.generate( + prompts=["hello world"], + generation_config=DiffusionGenerationConfig( + num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0] + ), + ) + + +def profile_context(args): + return ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) + if args.profile + else nullcontext() + ) + + +def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None): + log_data = { + "mode": mode, + "model": model_name, + "batch_size": args.batch_size, + "patched_parallel_size": args.patched_parallel_size, + "num_inference_steps": args.num_inference_steps, + "height": h, + "width": w, + "dtype": args.dtype, + "profile": args.profile, + "n_warm_up_steps": args.n_warm_up_steps, + "n_repeat_times": args.n_repeat_times, + "avg_generation_time": avg_time, + "log_message": log_msg, + } + + if args.log: + log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json" + log_generation_time(log_data=log_data, log_file=log_file) + + if args.profile: + file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json" + prof.export_chrome_trace(file) + + +def benchmark_colossalai(rank, world_size, port, args): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + from colossalai.cluster.dist_coordinator import DistCoordinator + + coordinator = DistCoordinator() + + inference_config = InferenceConfig( + dtype=args.dtype, + patched_parallelism_size=args.patched_parallel_size, + ) + engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False) + + warmup(engine, args) + + for h, w in zip(args.height, args.width): + with profile_context(args) as prof: + start = time.perf_counter() + for _ in range(args.n_repeat_times): + engine.generate( + prompts=["hello world"], + generation_config=DiffusionGenerationConfig( + num_inference_steps=args.num_inference_steps, height=h, width=w + ), + ) + end = time.perf_counter() + + avg_time = (end - start) / args.n_repeat_times + log_msg = 
f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s" + coordinator.print_on_master(log_msg) + + if dist.get_rank() == 0: + log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof) + + +def benchmark_diffusers(args): + model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda") + + for _ in range(args.n_warm_up_steps): + model( + prompt="hello world", + num_inference_steps=args.num_inference_steps, + height=args.height[0], + width=args.width[0], + ) + + for h, w in zip(args.height, args.width): + with profile_context(args) as prof: + start = time.perf_counter() + for _ in range(args.n_repeat_times): + model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w) + end = time.perf_counter() + + avg_time = (end - start) / args.n_repeat_times + log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s" + print(log_msg) + + log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def benchmark(args): + if args.mode == "colossalai": + spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args) + elif args.mode == "diffusers": + benchmark_diffusers(args) + + +""" +# enable log +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log + +# enable profiler +python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +""" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size") + parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps") + parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list") + parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list") + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps") + parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times") + parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler") + parser.add_argument("--log", default=False, action="store_true", help="Enable logging") + parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path") + parser.add_argument( + "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode" + ) + args = 
parser.parse_args() + benchmark(args) diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py new file mode 100644 index 000000000..14c92501b --- /dev/null +++ b/examples/inference/stable_diffusion/compute_metric.py @@ -0,0 +1,80 @@ +# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py +import argparse +import os + +import numpy as np +import torch +from cleanfid import fid +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio +from torchvision.transforms import Resize +from tqdm import tqdm + + +def read_image(path: str): + """ + input: path + output: tensor (C, H, W) + """ + img = np.asarray(Image.open(path)) + if len(img.shape) == 2: + img = np.repeat(img[:, :, None], 3, axis=2) + img = torch.from_numpy(img).permute(2, 0, 1) + return img + + +class MultiImageDataset(Dataset): + def __init__(self, root0, root1, is_gt=False): + super().__init__() + self.root0 = root0 + self.root1 = root1 + file_names0 = os.listdir(root0) + file_names1 = os.listdir(root1) + + self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")]) + self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")]) + self.is_gt = is_gt + assert len(self.image_names0) == len(self.image_names1) + + def __len__(self): + return len(self.image_names0) + + def __getitem__(self, idx): + img0 = read_image(os.path.join(self.root0, self.image_names0[idx])) + if self.is_gt: + # resize to 1024 x 1024 + img0 = Resize((1024, 1024))(img0) + img1 = read_image(os.path.join(self.root1, self.image_names1[idx])) + + batch_list = [img0, img1] + return batch_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--is_gt", action="store_true") + parser.add_argument("--input_root0", type=str, required=True) + parser.add_argument("--input_root1", type=str, required=True) + args = parser.parse_args() + + psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda") + lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda") + + dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt) + dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) + + progress_bar = tqdm(dataloader) + with torch.inference_mode(): + for i, batch in enumerate(progress_bar): + batch = [img.to("cuda") / 255 for img in batch] + batch_size = batch[0].shape[0] + psnr.update(batch[0], batch[1]) + lpips.update(batch[0], batch[1]) + fid_score = fid.compute_fid(args.input_root0, args.input_root1) + + print("PSNR:", psnr.compute().item()) + print("LPIPS:", lpips.compute().item()) + print("FID:", fid_score) diff --git a/examples/inference/stable_diffusion/requirements.txt b/examples/inference/stable_diffusion/requirements.txt new file mode 100644 index 000000000..c4e74162d --- /dev/null +++ b/examples/inference/stable_diffusion/requirements.txt @@ -0,0 +1,3 @@ +torchvision +torchmetrics +cleanfid diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh new file mode 100644 index 000000000..f3e45a335 --- /dev/null +++ 
b/examples/inference/stable_diffusion/run_benchmark.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers") +parallelism=(1 2 4 8) +resolutions=(1024 2048 3840) +modes=("colossalai" "diffusers") + +CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +for model in "${models[@]}"; do + for p in "${parallelism[@]}"; do + for resolution in "${resolutions[@]}"; do + for mode in "${modes[@]}"; do + if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then + continue + fi + if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then + continue + fi + CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p + + cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution" + + echo "Executing: $cmd" + eval $cmd + done + done + done +done diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py new file mode 100644 index 000000000..9e146c34b --- /dev/null +++ b/examples/inference/stable_diffusion/sd3_generation.py @@ -0,0 +1,81 @@ +import argparse + +from diffusers import DiffusionPipeline +from torch import bfloat16 +from torch import distributed as dist +from torch import float16, float32 + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine + +# For Stable Diffusion 3, we'll use the following configuration +MODEL_CLS = DiffusionPipeline + +TORCH_DTYPE_MAP = { + "fp16": float16, + "fp32": float32, + "bf16": bfloat16, +} + + +def infer(args): + # ============================== + # Launch colossalai, setup distributed environment + # ============================== + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ============================== + # Load model and tokenizer + # ============================== + model_path_or_name = args.model + model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None)) + + # ============================== + # Initialize InferenceEngine + # ============================== + coordinator.print_on_master(f"Initializing Inference Engine...") + inference_config = InferenceConfig( + dtype=args.dtype, + max_batch_size=args.max_batch_size, + tp_size=args.tp_size, + use_cuda_kernel=args.use_cuda_kernel, + patched_parallelism_size=dist.get_world_size(), + ) + engine = InferenceEngine(model, inference_config=inference_config, verbose=True) + + # ============================== + # Generation + # ============================== + coordinator.print_on_master(f"Generating...") + out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0] + if dist.get_rank() == 0: + out.save(f"cat_parallel_size{dist.get_world_size()}.jpg") + coordinator.print_on_master(out) + + +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH + +# colossalai run --nproc_per_node 1 
examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 +# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 + +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 +# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--model", type=str, help="Path to the model or model name") + parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size") + parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt") + parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size") + parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"]) + parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default") + args = parser.parse_args() + + infer(args) diff --git a/examples/inference/stable_diffusion/test_ci.sh b/examples/inference/stable_diffusion/test_ci.sh new file mode 100644 index 000000000..d0189431c --- /dev/null +++ b/examples/inference/stable_diffusion/test_ci.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Skip the test (this test is slow)" diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 9b3a10160..f447adf69 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -1,4 +1,5 @@ import argparse +from contextlib import nullcontext from typing import Callable, List, Union import evaluate @@ -17,6 +18,7 @@ from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam # ============================== @@ -252,10 +254,16 @@ def main(): pad_token_id=data_builder.tokenizer.pad_token_id, ) - if model_name == "gpt2": - model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() - else: - raise RuntimeError + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, (GeminiPlugin)) + else nullcontext() + ) + with init_ctx: + if model_name == "gpt2": + model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + else: + raise RuntimeError # optimizer no_decay = ["bias", "LayerNorm.weight"] diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 8a35db1f7..e530e2d6a 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -98,6 +98,7 @@ def main(): parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", 
action="store_true") + parser.add_argument("--overlap_allgather", action="store_true") args = parser.parse_args() colossalai.launch_from_torch() @@ -199,9 +200,9 @@ def main(): enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", - dp_outside=False, overlap_p2p=args.overlap, enable_metadata_cache=not args.no_cache, + overlap_allgather=args.overlap_allgather, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": @@ -292,7 +293,7 @@ def main(): with get_profile_context( args.profile, args.ignore_steps, - len(dataloader) - 1, + 1, # avoid creating massive log files save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index c2883d96c..ca9b63d1a 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -1,4 +1,5 @@ import time +from contextlib import nullcontext import torch import tqdm @@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -62,14 +65,6 @@ def main(): if args.mem_cap > 0: colo_memory_cap(args.mem_cap) - # Build OPT model - config = AutoConfig.from_pretrained(args.model_name_or_path) - model = OPTForCausalLM(config=config) - logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == "torch_ddp_fp16": @@ -82,6 +77,19 @@ def main(): plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + # Build OPT model + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, (GeminiPlugin)) + else nullcontext() + ) + config = AutoConfig.from_pretrained(args.model_name_or_path) + with init_ctx: + model = OPTForCausalLM(config=config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() # Set optimizer optimizer = HybridAdam(model.parameters(), lr=args.learning_rate) diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index b5b50305c..50dfc7bff 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import datasets import torch import transformers @@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, 
get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -78,14 +82,6 @@ def main(): datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() - # Build OPT model - config = AutoConfig.from_pretrained(args.model_name_or_path) - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config) - logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == "torch_ddp_fp16": @@ -110,6 +106,21 @@ def main(): logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + # Build OPT model + config = AutoConfig.from_pretrained(args.model_name_or_path) + # Build OPT model + init_ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) + with init_ctx: + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + # Prepare tokenizer and dataloader tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) dataset = NetflixDataset(tokenizer) diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index 6b8daf37d..ca4a02cd2 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -113,13 +113,13 @@ class PerformanceEvaluator: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - get_accelerator().synchronize() + # get_accelerator().synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - get_accelerator().synchronize() + # get_accelerator().synchronize() self.timer.end() batch_size, seq_len = input_ids.shape diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e4affc7f5..93a3690fe 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,4 +1,3 @@ -diffusers pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 27bbc3769..651eb66e8 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.1.0,<2.3.0 +torch>=2.1.0,<=2.3.0 safetensors einops pydantic @@ -23,3 +23,4 @@ rpyc==6.0.0 fastapi uvicorn==0.29.0 galore_torch +diffusers==0.29.0 diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 05c17f562..4adc38619 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -3,28 +3,17 @@ from .bert import * from .blip2 import * from .bloom import * from .chatglm2 import * +from .command import * +from .deepseek import * from .falcon import * from .gpt import * from .gptj import * from .llama import * +from .mistral import * +from .mixtral import * from .opt import * +from .qwen2 import * from .sam import * from .t5 import * from .vit import * from .whisper import * - -try: - from .mistral import * -except ImportError: - print("This version of transformers doesn't support mistral.") - -try: - from .qwen2 import * -except ImportError: - print("This version of transformers 
doesn't support qwen2.") - - -try: - from .command import * -except ImportError: - print("This version of transformers doesn't support Command-R.") diff --git a/tests/kit/model_zoo/transformers/deepseek.py b/tests/kit/model_zoo/transformers/deepseek.py new file mode 100644 index 000000000..ad73640a5 --- /dev/null +++ b/tests/kit/model_zoo/transformers/deepseek.py @@ -0,0 +1,83 @@ +# modified from tests/kit/model_zoo/transformers/mistral.py +import torch +import transformers +from transformers import AutoConfig + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Mixtral +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoModelForCausalLM, AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1") + # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement) + # tokenized_input = tokenizer([input], return_tensors="pt") + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_mixtral_model = lambda x: x[0].mean() +loss_fn = lambda x: x.loss +loss_fn_for_seq_classification = lambda output: output.logits.mean() + + +def init_deepseek(): + + config = AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + hidden_size=32, + intermediate_size=32, + moe_intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=8, + # vocab_size=2200, + first_k_dense_replace=1, + attn_implementation="flash_attention_2", + torch_dtype="float16", + n_routed_experts=8, + trust_remote_code=True, + ) + + if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + model = transformers.AutoModel.from_config(config, trust_remote_code=True) + + return model + + +model_zoo.register( + name="transformers_deepseek", + model_fn=init_deepseek, + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_mixtral_model, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py new file mode 100644 index 000000000..73c0e9e2c --- /dev/null +++ b/tests/kit/model_zoo/transformers/mixtral.py @@ -0,0 +1,85 @@ +# modified from tests/kit/model_zoo/transformers/mistral.py +import torch +import transformers +from transformers import MixtralConfig + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Mixtral +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoModelForCausalLM, AutoTokenizer + # tokenizer = 
AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1") + # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement) + # tokenized_input = tokenizer([input], return_tensors="pt") + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_mixtral_model = lambda x: x[0].mean() +loss_fn = lambda x: x.loss +loss_fn_for_seq_classification = lambda output: output.logits.mean() + +config = MixtralConfig( + hidden_size=32, + intermediate_size=32, + num_attention_heads=8, + num_hidden_layers=2, + vocab_size=1000, + output_router_logits=True, +) + +if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + +model_zoo.register( + name="transformers_mixtral", + model_fn=lambda: transformers.MixtralModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_mixtral_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +# model_zoo.register( +# name="transformers_mixtral_for_casual_lm", +# model_fn=lambda: transformers.MixtralForCausalLM(config), +# data_gen_fn=data_gen_for_lm, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_mixtral_for_sequence_classification", +# model_fn=lambda: transformers.MixtralForSequenceClassification(config), +# data_gen_fn=data_gen_for_sequence_classification, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn_for_seq_classification, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) diff --git a/tests/test_legacy/test_moe/moe_utils.py b/tests/test_legacy/test_moe/moe_utils.py new file mode 100644 index 000000000..8c133849b --- /dev/null +++ b/tests/test_legacy/test_moe/moe_utils.py @@ -0,0 +1,136 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler +from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce +from colossalai.legacy.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.utils import get_moe_epsize_param_dict +from colossalai.legacy.registry import GRADIENT_HANDLER +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group + + +def delete_moe_info(model): + for _, param in model.named_parameters(): + if hasattr(param, "ep_group"): + delattr(param, "ep_group") + + +class MoeModel(nn.Module): + def __init__(self, ep_group: ProcessGroup = None): + super().__init__() + self.test_embed = nn.Linear(4, 16, bias=False) + self.w1 = 
torch.nn.Parameter(torch.randn(16, 8)) + if ep_group: + set_moe_tensor_ep_group(self.w1, ep_group) + + def forward(self, x): + x = self.test_embed(x) + x = torch.matmul(x, self.w1) + + return x + + +@GRADIENT_HANDLER.register_module +class MoeGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group and + moe model parallel. A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer=None): + super().__init__(model, optimizer) + + def handle_gradient(self): + """A method running an all-reduce operation in a data parallel group. + Then running an all-reduce operation for all parameters in experts + across moe model parallel group + """ + if dist.get_world_size() > 1: + epsize_param_dict = get_moe_epsize_param_dict(self._model) + + # epsize is 1, indicating the params are replicated among processes in data parallelism + # use the ParallelMode.DATA to get data parallel group + # reduce gradients for all parameters in data parallelism + if 1 in epsize_param_dict: + bucket_allreduce(param_list=epsize_param_dict[1]) + + for ep_size in epsize_param_dict: + if ep_size != 1 and ep_size != MOE_MANAGER.world_size: + bucket_allreduce( + param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group + ) + + +def assert_not_equal_in_group(tensor, process_group=None): + # all gather tensors from different ranks + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group=process_group) + + # check if they are equal one by one + for i in range(world_size - 1): + a = tensor_list[i] + b = tensor_list[i + 1] + assert not torch.allclose(a, b), ( + f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}" + ) + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + 
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_legacy/test_moe/test_grad_handler.py similarity index 98% rename from tests/test_moe/test_grad_handler.py rename to tests/test_legacy/test_moe/test_grad_handler.py index 25e61b091..3a782a6dd 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_legacy/test_moe/test_grad_handler.py @@ -5,7 +5,7 @@ import torch.nn as nn import colossalai from colossalai.accelerator import get_accelerator -from colossalai.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.manager import MOE_MANAGER # from colossalai.shardformer.layer.moe.layers import SparseMLP from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn diff --git a/tests/test_moe/test_moe_group.py b/tests/test_legacy/test_moe/test_moe_group.py similarity index 95% rename from tests/test_moe/test_moe_group.py rename to tests/test_legacy/test_moe/test_moe_group.py index 89baf1d37..68dac4828 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_legacy/test_moe/test_moe_group.py @@ -4,8 +4,8 @@ import torch.nn as nn import colossalai from colossalai.accelerator import get_accelerator -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import sync_moe_model_param +from colossalai.legacy.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.utils import sync_moe_model_param # from colossalai.shardformer.layer.moe import MLPExperts from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_legacy/test_moe/test_moe_hybrid_zero.py similarity index 98% rename from tests/test_moe/test_moe_hybrid_zero.py rename to tests/test_legacy/test_moe/test_moe_hybrid_zero.py index 513c4ebda..fdd6d956e 100644 --- a/tests/test_moe/test_moe_hybrid_zero.py +++ b/tests/test_legacy/test_moe/test_moe_hybrid_zero.py @@ -6,7 +6,7 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.test_moe.moe_utils import MoeModel diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_legacy/test_moe/test_moe_load_balance.py similarity index 99% rename from tests/test_moe/test_moe_load_balance.py rename to tests/test_legacy/test_moe/test_moe_load_balance.py index ddd3ea368..adf2dbc1c 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_legacy/test_moe/test_moe_load_balance.py @@ -6,7 +6,7 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.moe.manager import MOE_MANAGER +from colossalai.legacy.moe.manager import MOE_MANAGER # from colossalai.shardformer.layer.moe import apply_load_balance from colossalai.tensor.moe_tensor.api import is_moe_tensor diff --git a/tests/test_lora/test_lora.py 
b/tests/test_lora/test_lora.py index b8daf775d..1ae17025d 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -9,7 +9,8 @@ from torch.optim import AdamW import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_checkpoint_io.utils import shared_tempdir @@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type model = model_fn() lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)] test_configs = [ { "lora_config": lora_config, @@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type # test fwd bwd correctness test_model = model_load + if isinstance(model_load, HybridParallelModule): + model_load = model_load.module.module model_copy = copy.deepcopy(model_load) data = data_gen_fn() diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 131932dcb..8c411a33f 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,142 +1,8 @@ import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup -from torch.testing import assert_close - -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler -from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce -from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import get_moe_epsize_param_dict - -# from colossalai.shardformer.layer.moe import SparseMLP -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group -def delete_moe_info(model): - for _, param in model.named_parameters(): - if hasattr(param, "ep_group"): - delattr(param, "ep_group") - - -class MoeModel(nn.Module): - def __init__(self, ep_group: ProcessGroup = None): - super().__init__() - self.test_embed = nn.Linear(4, 16, bias=False) - self.w1 = torch.nn.Parameter(torch.randn(16, 8)) - if ep_group: - set_moe_tensor_ep_group(self.w1, ep_group) - - def forward(self, x): - x = self.test_embed(x) - x = torch.matmul(x, self.w1) - - return x - - -@GRADIENT_HANDLER.register_module -class MoeGradientHandler(BaseGradientHandler): - """A helper class to handle all-reduce operations in a data parallel group and - moe model parallel. A all-reduce collective communication will be operated in - :func:`handle_gradient` among a data parallel group. - For better performance, it bucketizes the gradients of all parameters that are - the same type to improve the efficiency of communication. - - Args: - model (Module): Model where the gradients accumulate. - optimizer (Optimizer): Optimizer for updating the parameters. 
- """ - - def __init__(self, model, optimizer=None): - super().__init__(model, optimizer) - - def handle_gradient(self): - """A method running an all-reduce operation in a data parallel group. - Then running an all-reduce operation for all parameters in experts - across moe model parallel group - """ - if dist.get_world_size() > 1: - epsize_param_dict = get_moe_epsize_param_dict(self._model) - - # epsize is 1, indicating the params are replicated among processes in data parallelism - # use the ParallelMode.DATA to get data parallel group - # reduce gradients for all parameters in data parallelism - if 1 in epsize_param_dict: - bucket_allreduce(param_list=epsize_param_dict[1]) - - for ep_size in epsize_param_dict: - if ep_size != 1 and ep_size != MOE_MANAGER.world_size: - bucket_allreduce( - param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group - ) - - -def assert_not_equal_in_group(tensor, process_group=None): - # all gather tensors from different ranks - world_size = dist.get_world_size(process_group) - tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] - dist.all_gather(tensor_list, tensor, group=process_group) - - # check if they are equal one by one - for i in range(world_size - 1): - a = tensor_list[i] - b = tensor_list[i + 1] - assert not torch.allclose(a, b), ( - f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}" - ) - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model - - Args: - local_model (MoeModule) - ep_model (MoeModule) - """ - for (local_name, local_param), (ep_name, ep_param) in zip( - local_model.named_parameters(), ep_model.named_parameters() - ): - if "experts" not in local_name: - if assert_grad_flag: - assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" - assert torch.allclose(local_param.grad, ep_param.grad) - else: - local_param.data.copy_(ep_param.data) - continue - - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) - - if assert_grad_flag: - assert torch.allclose(local_param, all_param) - assert torch.allclose(local_param.grad, all_grad) - else: - local_param.data.copy_(all_param.data) +def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""): + assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}" def loose_close(a, b, dtype: torch.dtype = torch.float32): @@ -148,8 +14,18 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): elif dtype is torch.bfloat16: rtol = 4e-3 atol = 4e-3 + else: + assert dtype is torch.float32 + rtol = 1e-05 + atol = 1e-08 a = a.detach().to(dtype) b = b.detach().to(dtype).to(a.device) - 
assert_close(a, b, rtol=rtol, atol=atol) + return torch.allclose(a, b, rtol=rtol, atol=atol) + + +def check_model_equal(model1, model2): + assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) + for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())): + assert_loose_close(p1, p2, p1.dtype) diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py new file mode 100644 index 000000000..d18ba2eac --- /dev/null +++ b/tests/test_moe/test_deepseek_layer.py @@ -0,0 +1,78 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from torch.testing import assert_close +from transformers import AutoConfig, AutoModel + +import colossalai +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_deepseek_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + plugin = MoeHybridParallelPlugin( + precision="bf16", + tp_size=1, + pp_size=1, + zero_stage=1, + ep_size=dist.get_world_size(), + ) + + config = AutoConfig.from_pretrained( + "deepseek-ai/deepseek-moe-16b-base", + num_hidden_layers=1, + n_routed_experts=n_experts, + num_experts_per_tok=top_k, + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + first_k_dense_replace=0, + num_attention_heads=2, + trust_remote_code=True, + ) + torch.manual_seed(0) + # get the moe layer in auto model + orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda() + x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() + orig_output = orig_model(x) + model = deepcopy(orig_model) + model = EPDeepseekMoE.from_native_module( + model, + ep_group=plugin.ep_group, + moe_dp_group=plugin.moe_dp_group, + tp_group=plugin.tp_group, + ) + ep_output = model(x) + assert_close(orig_output, ep_output) + orig_loss = orig_output.mean() + orig_loss.backward() + ep_loss = ep_output.mean() + ep_loss.backward() + assert_close(orig_loss, ep_loss) + name_to_p = {n: p for n, p in orig_model.named_parameters()} + for n, ep_p in model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch(rank, world_size, "localhost", port) + check_deepseek_moe_layer() + + +@pytest.mark.skip("tested in corresponding shardformer") +@pytest.mark.parametrize("world_size", [2]) +def test_deepseek_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_deepseek_moe_layer(2) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 28e6db441..c81023988 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -4,8 +4,6 @@ import pytest import torch from colossalai.accelerator import get_accelerator - -# from colossalai.moe import SparseMLP from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum NUM_EXPERTS = 4 diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py index b7b0322e0..bc41ac4f3 100644 --- a/tests/test_moe/test_mixtral_layer.py +++ b/tests/test_moe/test_mixtral_layer.py @@ -23,6 +23,7 @@ def check_mixtral_moe_layer(): precision="bf16", tp_size=1, pp_size=1, + zero_stage=1, ep_size=dist.get_world_size(), ) config = MixtralConfig( @@ -36,7 +37,12 @@ def
check_mixtral_moe_layer(): x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() orig_output, orig_logits = orig_model(x) model = deepcopy(orig_model) - model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group) + model = EPMixtralSparseMoeBlock.from_native_module( + model, + ep_group=plugin.ep_group, + tp_group=plugin.tp_group, + moe_dp_group=plugin.moe_dp_group, + ) ep_output, ep_logits = model(x) assert_close(orig_logits, ep_logits) assert_close(orig_output, ep_output) @@ -57,7 +63,8 @@ def run_dist(rank: int, world_size: int, port: int): check_mixtral_moe_layer() -@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.skip("tested in corresponding shardformer") +@pytest.mark.parametrize("world_size", [2]) def test_mixtral_moe_layer(world_size: int): spawn(run_dist, world_size) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 249dd4b97..89f5d1c64 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -6,30 +6,23 @@ from copy import deepcopy import pytest import torch import torch.distributed as dist -from torch.optim import Adam +from torch.optim import SGD, Adam from transformers.models.mixtral.configuration_mixtral import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM import colossalai from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.checkpoint_io import MoECheckpointIO -from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import parameterize, spawn +from colossalai.testing.random import seed_all from colossalai.testing.utils import spawn +from tests.test_moe.moe_utils import check_model_equal tokens, n_experts = 7, 4 hidden_size = 8 top_k = 2 -def check_model_equal(model1, model2): - assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) - for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())): - if not torch.equal(p1.half(), p2.half()): - print(f"Model parameter {name} is not equal.
is_moe_tensor: {is_moe_tensor(p1)}") - raise AssertionError(f"Model parameter {name} is not equal") - - def get_optimizer_snapshot(optim): state = {id(k): deepcopy(v) for k, v in optim.state.items()} param_groups = [] @@ -77,36 +70,44 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou raise AssertionError(f"A total of {count} optim states are not equal") -def check_mixtral_moe_layer(): +@parameterize( + "test_config", + [ + [ + MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + num_attention_heads=2, + num_key_value_heads=2, + num_hidden_layers=2, + ), + MixtralForCausalLM, + ], + ], +) +def check_moe_checkpoint(test_config): + dtype, precision = torch.float16, "fp16" + config, model_cls = test_config + torch.cuda.set_device(dist.get_rank()) + context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() with context as f: - torch.cuda.set_device(dist.get_rank()) if dist.get_rank() == 0: broadcast_objects = [f] # any picklable object else: broadcast_objects = [None] dist.broadcast_object_list(broadcast_objects, src=0) - config = MixtralConfig( - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_local_experts=n_experts, - num_experts_per_tok=top_k, - num_attention_heads=2, - num_key_value_heads=2, - ) - torch.manual_seed(0) input_ids = torch.randint(0, 100, (2, tokens)).cuda() - orig_model = MixtralForCausalLM(config).cuda() + orig_model = model_cls(config).cuda().to(dtype) + + seed_all(10086) model = deepcopy(orig_model) - optimizer = Adam(model.parameters(), lr=1e-3) + optimizer = SGD(model.parameters(), lr=1e-3) plugin = MoeHybridParallelPlugin( - pp_size=2, - ep_size=2, - tp_size=1, - checkpoint_io=MoECheckpointIO, - microbatch_size=1, - zero_stage=1, + pp_size=2, ep_size=2, tp_size=1, microbatch_size=1, zero_stage=1, precision=precision ) booster = Booster(plugin=plugin) model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer) @@ -120,7 +121,6 @@ def check_mixtral_moe_layer(): lambda outputs, inputs: outputs.loss, optimizer, ) - tmpdirname = broadcast_objects[0] model_dir = os.path.join(tmpdirname, "mixtral_model") hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model") @@ -129,13 +129,12 @@ def check_mixtral_moe_layer(): booster.save_model(model, model_dir, shard=True) dist.barrier() if dist.get_rank() == 0: - saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda() + saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype) check_model_equal(orig_model, saved_model) - # check_model_equal(model, saved_model) saved_model.save_pretrained(hf_model_dir) dist.barrier() # check load model - new_model = MixtralForCausalLM(config).cuda() + new_model = model_cls(config).cuda().to(dtype) new_optimizer = Adam(new_model.parameters(), lr=1e-3) new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) booster.load_model(new_model, hf_model_dir) @@ -163,7 +162,7 @@ def check_mixtral_moe_layer(): def run_dist(rank: int, world_size: int, port: int): colossalai.launch(rank, world_size, "localhost", port) - check_mixtral_moe_layer() + check_moe_checkpoint() # Test EP + ZeRO + PP diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 9bc11033a..e6d2609ee 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -1,238 +1,132 @@ -import os -import warnings -from typing import Dict +from copy import deepcopy import 
pytest import torch import torch.distributed as dist +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralModel import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import sync_moe_model_param +from colossalai.booster.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import assert_loose_close -# from colossalai.shardformer.layer import SparseMLP -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from tests.test_moe.moe_utils import MoeGradientHandler +NUM_BATCH = 4 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 2 -def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from local model +@parameterize("stage", [1]) +@parameterize("ep_size", [2]) +def run_zero_with_original_model(stage: int, ep_size: int): + tp_size = dist.get_world_size() // ep_size + dtype = torch.bfloat16 - Args: - tp_model (MoeModule) - local_model (MoeModule) - """ - for (tp_name, tp_param), (local_name, local_param) in zip( - tp_model.named_parameters(), local_model.named_parameters() - ): - assert tp_name == local_name - if not is_moe_tensor(tp_param): - if assert_grad_flag: - assert torch.allclose(tp_param, local_param) - assert torch.allclose(tp_param.grad, local_param.grad) - else: - tp_param.data.copy_(local_param.data) - continue + rank = torch.distributed.get_rank() + torch.cuda.set_device(dist.get_rank()) - tp_rank = get_ep_rank(tp_param) - tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0] - tp_slice = [slice(None)] * tp_dim + [ - slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) - ] + seed_all(10086) - if assert_grad_flag: - assert torch.allclose(tp_param, local_param[tuple(tp_slice)]) - assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)]) - else: - tp_param.data.copy_(local_param[tuple(tp_slice)].data) - - -def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model - - Args: - tp_model (MoeModule) - ep_model (MoeModule) - """ - for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): - assert tp_name == ep_name - if not is_moe_tensor(tp_param): - if assert_grad_flag: - assert torch.allclose(tp_param, ep_param) - assert torch.allclose(tp_param.grad, ep_param.grad) - else: - tp_param.data.copy_(ep_param.data) - continue - - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) - - # get tp param - tp_dim = [i for 
i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1 - tp_rank = get_ep_rank(tp_param) - tp_slice = [slice(None)] * tp_dim + [ - slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) - ] - new_tp_param = all_param[tuple(tp_slice)] - if assert_grad_flag: - new_grad = all_grad[tuple(tp_slice)] - if assert_grad_flag: - assert torch.allclose(tp_param, new_tp_param) - assert torch.allclose(tp_param.grad, new_grad) - else: - tp_param.data.copy_(new_tp_param.data) - - -def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None: - """Sync the parameters of tp model from ep model - - Args: - local_model (MoeModule) - ep_model (MoeModule) - """ - for (local_name, local_param), (ep_name, ep_param) in zip( - local_model.named_parameters(), ep_model.named_parameters() - ): - assert local_name == ep_name - if "experts" not in local_name: - if assert_grad_flag: - assert torch.allclose(local_param, ep_param) - assert torch.allclose(local_param.grad, ep_param.grad) - else: - local_param.data.copy_(ep_param.data) - continue - - # gather param from ep model - param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) - all_param = torch.cat(param_list, dim=0) - if assert_grad_flag: - grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] - dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) - all_grad = torch.cat(grad_list, dim=0) - - if assert_grad_flag: - assert torch.allclose(local_param, all_param) - assert torch.allclose(local_param.grad, all_grad) - else: - local_param.data.copy_(all_param.data) - - -def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict): - assert batch_size % world_size == 0 - - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel=None) - local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="EP") - enable_hierarchical_comm = config.get("enable_hierarchical_comm", False) - if enable_hierarchical_comm: - os.environ["LOCAL_WORLD_SIZE"] = str(world_size) - ep_model = SparseMLP( - num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2, - enable_hierarchical_comm=enable_hierarchical_comm, + config = MixtralConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=2, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + num_local_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, ) - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="TP") - tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - ep_model = ep_model.to(get_accelerator().get_current_device()) - tp_model = tp_model.to(get_accelerator().get_current_device()) - local_model = local_model.to(get_accelerator().get_current_device()) + torch_model = MixtralModel(config).to(dtype).cuda() - # sync ep param - sync_moe_model_param(ep_model) - dist_dict = MOE_MANAGER.parallel_info_dict - assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group) - assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group) - ep_grad_handler = MoeGradientHandler(ep_model) - # sync local param - 
sync_local_from_ep(local_model, ep_model) - # sync tp param - sync_tp_from_ep(tp_model, ep_model) - tp_grad_handler = MoeGradientHandler(tp_model) - - rank = dist.get_rank() - input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device()) - micro_batch_size = batch_size // world_size - index = rank * micro_batch_size - # NOTE: ep & tp takes in sharded data for each process - shard_data = input_data.detach()[index : index + micro_batch_size] - - out_local = local_model(input_data) - MOE_MANAGER.reset_loss() - out_tp = tp_model(shard_data) - MOE_MANAGER.reset_loss() - out_ep = ep_model(shard_data) - MOE_MANAGER.reset_loss() - - assert torch.allclose( - out_tp, out_ep, atol=1e-6 - ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" - try: - out_local_slice = out_local[index : index + micro_batch_size] - assert torch.allclose( - out_ep, out_local_slice, atol=1e-6 - ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" - except AssertionError: - """ - e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 - router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2 - However, in ep mode, there are 2 separate routers dealing with sharded data. - Assume router 0 handles token [01] and router 1 handles token [23]. - Note that for each router the capacity is only 1 !!! - Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both. - The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. - """ - warnings.warn( - "EP & TP may result in different behavior from local model. " "Please check the comments for details." + zero_model = deepcopy(torch_model).to(dtype) + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + moe_booster = Booster( + plugin=MoeHybridParallelPlugin( + tp_size=tp_size, + moe_tp_size=tp_size, + pp_size=1, + ep_size=ep_size, + zero_stage=stage, + overlap_communication=False, + initial_scale=1, ) + ) + zero_model, zero_optimizer, _, _, _ = moe_booster.boost(zero_model, zero_optimizer) - out_local.mean().backward() - out_tp.mean().backward() - tp_grad_handler.handle_gradient() - out_ep.mean().backward() - ep_grad_handler.handle_gradient() - - assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group) - assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group) - sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) - try: - sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - except AssertionError: - warnings.warn( - "EP & TP may result in different behavior from local model. " "Please check the comments for details." 
+ hybrid_booster = Booster( + plugin=HybridParallelPlugin( + tp_size=tp_size, + pp_size=1, + zero_stage=stage, + overlap_communication=False, + initial_scale=1, ) + ) + hybrid_model, hybrid_optimizer, _, _, _ = hybrid_booster.boost( + torch_model, torch.optim.SGD(torch_model.parameters(), lr=1) + ) + # create different input + seed_all(1453 + rank) + + hybrid_model.train() + zero_model.train() + for _ in range(2): + # zero-dp forward + input_data = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + # zero-dp backward + zero_optimizer.backward(zero_output) + # torch-ddp forward + hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + assert_loose_close(zero_output, hybrid_output, dtype=dtype) + # torch-ddp backward + hybrid_optimizer.backward(hybrid_output) + + # check grad + name_to_p = {n: p for n, p in hybrid_model.named_parameters()} + for n, p in zero_model.named_parameters(): + zero_grad = zero_optimizer.get_param_grad(p) + if name_to_p[n].grad is None: + name_to_p[n].grad = torch.zeros_like(name_to_p[n]) + continue + if zero_grad.shape != name_to_p[n].grad.shape: # TODO check sharded and sliced moe + continue + assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n) + + # zero-dp step + zero_optimizer.step() + + # original model step + hybrid_optimizer.step() + + # check updated param + for n, p in zero_model.named_parameters(): + if p.data.shape != name_to_p[n].data.shape: # TODO check sharded and sliced moe + continue + assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n) + + print(f"{dist.get_rank()} test passed") -@pytest.mark.skip(reason="moe need to be refactored") +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_zero_with_original_model() + + +@pytest.mark.skip("tested in corresponding shardformer") +@pytest.mark.dist -@pytest.mark.parametrize("num_experts", [4, 64]) -@pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("dim", [64]) -@pytest.mark.parametrize( - "config", - [ - {"enable_hierarchical_comm": False}, - {"enable_hierarchical_comm": True}, - ], -) +@pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict): - spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config) +def test_moe_ep_tp(world_size): + spawn(run_dist, world_size) if __name__ == "__main__": - test_moe_ep_tp(num_experts=8, batch_size=32, dim=32) + test_moe_ep_tp(world_size=4) diff --git a/tests/test_moe/test_moe_ep_zero.py b/tests/test_moe/test_moe_ep_zero.py new file mode 100644 index 000000000..2d4e638b6 --- /dev/null +++ b/tests/test_moe/test_moe_ep_zero.py @@ -0,0 +1,119 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close + +NUM_BATCH = 4 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 2 +TOP_K = 1 + + +@parameterize("stage", [1]) +@parameterize("ep_size", [2, 4]) +def run_zero_with_original_model(stage: int, ep_size: int): + dtype = torch.bfloat16 + + rank = torch.distributed.get_rank() + torch.cuda.set_device(dist.get_rank()) + + plugin = MoeHybridParallelPlugin( + pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1 + ) + booster = Booster(plugin=plugin) + + seed_all(10086) + + config = MixtralConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=2, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + num_local_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, + ) + + torch_model = MixtralModel(config).to(dtype).cuda() + + zero_model = deepcopy(torch_model).to(dtype) + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + ddp_model = DDP( + torch_model.cuda(), + process_group=plugin.dp_group, + find_unused_parameters=True, # important for torch ddp, not all experts are routed + ).cuda() + ddp_optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1) + + # create different input + seed_all(1453 + rank) + + ddp_model.train() + zero_model.train() + for _ in range(2): + # zero-dp forward + input_data = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + # zero-dp backward + zero_optimizer.backward(zero_output) + + # torch-ddp forward + ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean() + assert_loose_close(zero_output, ddp_output, dtype=dtype) + # torch-ddp backward + ddp_output.backward() + + # check grad + name_to_p = {n: p for n, p in ddp_model.named_parameters()} + for n, p in zero_model.named_parameters(): + zero_grad = zero_optimizer.get_param_grad(p) + if name_to_p[n].grad is None: + name_to_p[n].grad = torch.zeros_like(name_to_p[n].data) + continue + assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n) + + # zero-dp step + zero_optimizer.step() + + # original model step + ddp_optimizer.step() + + # check updated param + for n, p in zero_model.named_parameters(): + assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n) + + print(f"{dist.get_rank()} test passed") + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_zero_with_original_model() + + +@pytest.mark.skip("tested in corresponding shardformer") +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_ep_zero(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_moe_ep_zero(world_size=4) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py deleted file mode 100644 index 042b3d8ae..000000000 --- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py +++ /dev/null @@ -1,132 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP -from transformers.models.mixtral.configuration_mixtral import
MixtralConfig -from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - -import colossalai -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock -from colossalai.tensor.moe_tensor.api import is_moe_tensor -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all -from colossalai.zero import LowLevelZeroOptimizer -from tests.test_moe.moe_utils import loose_close - -tokens, n_experts = 7, 4 -hidden_size = 8 -top_k = 2 - - -def split_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -@parameterize("dtype", [torch.float16, torch.bfloat16]) -@parameterize("master_weights", [True, False]) -@parameterize("stage", [1, 2]) -def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int): - rank = torch.distributed.get_rank() - torch.cuda.set_device(dist.get_rank()) - plugin = MoeHybridParallelPlugin( - tp_size=1, - pp_size=1, - ep_size=dist.get_world_size() // 2, - ) - - seed_all(10086) - config = MixtralConfig( - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_local_experts=n_experts, - num_experts_per_tok=top_k, - ) - - orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda() - - ori_model = DDP(orig_model.cuda(), static_graph=True).cuda() - - zero_model = deepcopy(orig_model).to(dtype) - zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group) - - zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) - pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []} - for p in zero_model.parameters(): - if is_moe_tensor(p): - pg_param_list[plugin.moe_dp_group].append(p) - else: - pg_param_list[plugin.global_dp_group].append(p) - - zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, - pg_to_param_list=pg_param_list, - master_weights=master_weights, - initial_scale=1, - overlap_communication=False, - partition_grad=True, - ) - - ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1) - - # create - seed_all(1453 + rank) - - for _ in range(2): - # zero-dp forward - input_data = torch.rand(1, tokens, hidden_size).cuda() - zero_output, zero_logits = zero_model(input_data.to(dtype)) - - # torch-ddp forward - ori_output, ori_logits = ori_model(input_data.to(dtype)) - loose_close(zero_output, ori_output, dtype=dtype) - - # zero-dp backward - zero_optimizer.backward(zero_output.mean().float()) - - # torch-ddp backward - ori_output.mean().backward() - - # check grad - name_to_p = {n: p for n, p in ori_model.module.named_parameters()} - for n, p in zero_model.named_parameters(): - zero_grad = zero_optimizer.get_param_grad(p) - if name_to_p[n].grad is None: - assert zero_grad is None - continue - - loose_close(zero_grad, name_to_p[n].grad, dtype=dtype) - - # zero-dp step - zero_optimizer.step() - - # original model step - ori_optimizer.step() - - # check updated param - for n, p in zero_model.named_parameters(): - loose_close(p.data, name_to_p[n].data, dtype=dtype) - - -def run_dist(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - 
run_zero_with_original_model(world_size=world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_moe_zero_model(world_size=4) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 1ffcc541a..190fee129 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,6 +1,6 @@ import copy from contextlib import nullcontext -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Type import torch import torch.distributed as dist @@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""): def build_model_from_hybrid_plugin( - model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam + model_fn: Callable, + loss_fn: Callable, + test_config: Dict[str, Any], + optim_class=Adam, + sharded_optim_class=Adam, + pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin, ): use_lazy_init = False if "use_lazy_init" in test_config: @@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin( else: org_optimizer = optim_class(org_model.parameters(), lr=1e-3) sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn - plugin = HybridParallelPlugin(**test_config) + plugin = pluggin_cls(**test_config) booster = Booster(plugin=plugin) sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 6ce020b68..92c077950 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -136,6 +136,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 4d66692a4..3281b50e1 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -58,6 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # Check the grad when using ZeRO-1 and ZeRO-2 if ( booster.plugin.zero_stage in [1, 2] + and booster.plugin.shard_config.pipeline_stage_manager is None 
         and booster.plugin.shard_config.enable_sequence_parallelism
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
@@ -154,6 +155,45 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {  # Ulysess + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 1,
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
new file mode 100644
index 000000000..46da4522f
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -0,0 +1,196 @@
+import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
+
+import pytest
+import torch
+import torch.distributed
+import torch.distributed as dist
+from transformers import AutoConfig, AutoModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
+
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
+NUM_LAYERS = 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 2
+
+
+CHECKED_CONFIG = [  # FOR_WORLD=4
+    (1, 4, 1, 1, 1),
+    (1, 1, 4, 1, 1),
+    (1, 1, 1, 4, 1),
+    (1, 1, 1, 1, 4),
+    (0, 1, 4, 1, 1),
+    (0, 1, 1, 4, 1),
+    (0, 1, 1, 1, 4),
+    (1, 2, 1, 1, 1),
+]
+
+
+@parameterize(
+    "config",
+    [
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
+    ],
+)
+def run_zero_with_original_model(config: Tuple[int, ...]):
+    stage, ep_size, pp_size, tp_size, sp_size = config
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+    dtype, precision = torch.float16, "fp16"
+    torch.cuda.set_device(dist.get_rank())
+
+    plugin = MoeHybridParallelPlugin(
+        pp_size=pp_size,
+        num_microbatches=pp_size,
+        tp_size=tp_size,
+        sp_size=sp_size,
+        ep_size=ep_size,
+        zero_stage=stage,
+        enable_sequence_parallelism=sp_size > 1,
+        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        enable_flash_attention=sp_size > 1,
+        overlap_communication=False,
+        initial_scale=1,
+        precision=precision,
+        find_unused_parameters=True,
+    )
+    dp_size = plugin.dp_size
+
+    booster = Booster(plugin=plugin)
+
+    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+    config = AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-moe-16b-base",
+        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        num_hidden_layers=4,
+        num_attention_heads=NUM_HEADS,
+        num_key_value_heads=NUM_HEADS,
+        first_k_dense_replace=1,
+        attn_implementation="flash_attention_2",
+        torch_dtype="float16",
+        n_routed_experts=NUM_EXPERTS,
+        num_experts_per_tok=TOP_K,
+        trust_remote_code=True,
+    )
+
+    # init model with the same seed
+    seed_all(10086)
+
+    torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+    parallel_model = deepcopy(torch_model)
+    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+
+    # create different input along dp axis
+    seed_all(1453 + rank)
+
+    torch_model.train()
+    parallel_model.train()
+    for _ in range(2):
+        # gen random input
+        input_embeddings = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(
+            input_embeddings, group=plugin.pp_group
+        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
+
+        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input
+
+        # run the model with hybrid parallel
+        if booster.plugin.stage_manager is not None:
+            # for test with pp
+            data_iter = iter([{"inputs_embeds": input_embeddings}])
+            sharded_output = booster.execute_pipeline(
+                data_iter,
+                parallel_model,
+                lambda x, y: x[0].mean(),
+                parallel_optimizer,
+                return_loss=True,
+                return_outputs=True,
+            )
+            if booster.plugin.stage_manager.is_last_stage():
+                parallel_output = sharded_output["loss"]
+            else:
+                parallel_output = torch.tensor(12345.0, device="cuda")
+
+            # broadcast along pp axis
+            dist.broadcast(
+                parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
+            )
+        else:
+            # for test without pp
+            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
+            parallel_optimizer.backward(parallel_output)
+        parallel_optimizer.step()
+        parallel_optimizer.zero_grad()
+        dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+        # ===================================================================================
+        # run normal model with all dp(different) inputs
+        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+        torch_output_sum = 0
+        for input_data_ in all_inputs:
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+            torch_output.backward()
+            torch_output_sum += torch_output.detach()
+        # avg dp grads follows zero optimizer
+        for p in torch_model.parameters():
+            if p.grad is not None:
+                p.grad /= dp_size
+        torch_optimizer.step()
+        torch_optimizer.zero_grad()
+
+        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+
+    # use checkpoint to load sharded zero model
+    model_dir = "./test_deepseek"
+    if rank == world_size - 1:
+        os.makedirs(model_dir, exist_ok=True)
+
+    dist.barrier()
+    booster.save_model(parallel_model, model_dir, shard=True)
+    dist.barrier()
+
+    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
+    check_model_equal(torch_model, saved_model)
+    dist.barrier()
+
+    if rank == world_size - 1:
+        shutil.rmtree(model_dir)
+
+    print(f"rank {dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_deepseek(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_deepseek(world_size=4)
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 8fe18f69b..88e54176b 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if (
         booster.plugin.zero_stage in [1, 2]
         and booster.plugin.shard_config.enable_sequence_parallelism
+        and booster.plugin.shard_config.pipeline_stage_manager is None
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
+        master2working = sharded_optimizer.get_master_to_working_map()
         for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-            working_p = sharded_optimizer.master_to_working_param[id(p2)]
+            working_p = master2working[id(p2)]
             grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
             grad_index = (
                 0
@@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {  # Ulysess + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {  # Test ring + Flash attention
             "tp_size": 2,
             "pp_size": 1,
@@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {  # Ulysess + Flash attention
-            "tp_size": 1,
-            "pp_size": 2,
-            "sp_size": 2,
-            "num_microbatches": 2,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 1,
             "pp_size": 1,
@@ -245,7 +247,6 @@ def run_llama_test(test_config):
         except Exception as e:
             print(f"Failed config: {test_config}")
             raise e
-
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
new file mode 100644
index 000000000..de09eedcb
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -0,0 +1,190 @@
+import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
+
+import pytest
+import torch
+import torch.distributed
+import torch.distributed as dist
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
+
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_LAYERS = 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1
+
+CHECKED_CONFIG = [  # FOR WORLD=4
+    (0, 1, 4, 1, 1),
+    (0, 1, 1, 4, 1),
+    (0, 1, 1, 1, 4),
+    (1, 4, 1, 1, 1),
+    (1, 1, 4, 1, 1),
+    (1, 1, 1, 4, 1),
+    (1, 1, 1, 1, 4),
+    (1, 2, 1, 1, 1),
+]
+
+
+@parameterize(
+    "config",
+    [
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
+    ],
+)
+def run_zero_with_original_model(config: Tuple[int, ...]):
+    stage, ep_size, pp_size, tp_size, sp_size = config
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+    dtype, precision = torch.float16, "fp16"
+    torch.cuda.set_device(dist.get_rank())
+
+    plugin = MoeHybridParallelPlugin(
+        pp_size=pp_size,
+        num_microbatches=pp_size,
+        tp_size=tp_size,
+        sp_size=sp_size,
+        ep_size=ep_size,
+        zero_stage=stage,
+        enable_sequence_parallelism=sp_size > 1,
+        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        overlap_communication=False,
+        initial_scale=1,
+        precision=precision,
+        find_unused_parameters=True,
+    )
+    dp_size = plugin.dp_size
+
+    booster = Booster(plugin=plugin)
+
+    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+    config = MixtralConfig(
+        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        num_hidden_layers=NUM_LAYERS,
+        num_attention_heads=NUM_HEADS,
+        num_key_value_heads=NUM_HEADS,
+        num_local_experts=NUM_EXPERTS,
+        num_experts_per_tok=TOP_K,
+        attn_implementation="flash_attention_2",
+    )
+
+    # init model with the same seed
+    seed_all(10086)
+
+    torch_model = MixtralModel(config).to(dtype).cuda()
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+    parallel_model = deepcopy(torch_model)
+    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+
+    # create different input along dp axis
+    seed_all(1453 + rank)
+
+    torch_model.train()
+    parallel_model.train()
+    for _ in range(2):
+        # gen random input
+        input_embeddings = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(
+            input_embeddings, group=plugin.pp_group
+        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
+
+        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input
+
+        # run the model with hybrid parallel
+        if booster.plugin.stage_manager is not None:
+            # for test with pp
+            data_iter = iter([{"inputs_embeds": input_embeddings}])
+            sharded_output = booster.execute_pipeline(
+                data_iter,
+                parallel_model,
+                lambda x, y: x.last_hidden_state.mean(),
+                parallel_optimizer,
+                return_loss=True,
+                return_outputs=True,
+            )
+            if booster.plugin.stage_manager.is_last_stage():
+                parallel_output = sharded_output["loss"]
+            else:
+                parallel_output = torch.tensor(12345.0, device="cuda")
+
+            # broadcast along pp axis
+            dist.broadcast(
+                parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
+            )
+        else:
+            # for test without pp
+            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
+            parallel_optimizer.backward(parallel_output)
+        parallel_optimizer.step()
+        parallel_optimizer.zero_grad()
+        dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+        # ===================================================================================
+        # run normal model with all dp(different) inputs
+        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+        torch_output_sum = 0
+        for input_data_ in all_inputs:
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+            torch_output.backward()
+            torch_output_sum += torch_output.detach()
+        # avg dp grads follows zero optimizer
+        for p in torch_model.parameters():
+            if p.grad is not None:
+                p.grad /= dp_size
+        torch_optimizer.step()
+        torch_optimizer.zero_grad()
+
+        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+
+    # use checkpoint to load sharded zero model
+    model_dir = "./test_mixtral"
+    if rank == world_size - 1:
+        os.makedirs(model_dir, exist_ok=True)
+
+    dist.barrier()
+    booster.save_model(parallel_model, model_dir, shard=True)
+    dist.barrier()
+
+    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
+    check_model_equal(torch_model, saved_model)
+    dist.barrier()
+
+    if rank == world_size - 1:
+        shutil.rmtree(model_dir)
+
+    print(f"rank {dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_mixtral(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_mixtral(world_size=4)
diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py
index 166b31df9..c87415b75 100644
--- a/tests/test_shardformer/test_model/test_shard_qwen2.py
+++ b/tests/test_shardformer/test_model/test_shard_qwen2.py
@@ -135,6 +135,68 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {  # Ulysess + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 1,
             "pp_size": 2,
@@ -151,8 +213,11 @@ def run_qwen2_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
@@ -197,7 +262,11 @@ def run_qwen2_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

     clear_layout_converter()
     Randomizer.reset_index()
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index ed12bb72d..94db70ca5 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
     zero1_optimizer.step()
     zero2_optimizer.step()

+    zero1_optimizer._force_wait_all_gather()
+    zero2_optimizer._force_wait_all_gather()
+
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
+        assert not hasattr(z1p, "_all_gather_handle")
         assert torch.equal(z1p.data, z2p.data)


diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index cda51c4ef..368c782fe 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -190,6 +190,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
         # torch ddp step
         torch_optimizer.step()

+        zero_optimizer._force_wait_all_gather()
+
         # check updated param
         for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
             loose_close(p, z1p, dtype=dtype)
diff --git a/version.txt b/version.txt
index 1d0ba9ea1..2b7c5ae01 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.4.0
+0.4.2