[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
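
The headline change of this merge is fp8 compression for collective communication. As a rough, hedged sketch of the idea only (not the code added in this PR; the helper name, the per-tensor scaling, and the uint8 transport are assumptions), an all-gather can ship float8 payloads plus scales and dequantize on arrival:

```python
import torch
import torch.distributed as dist


def all_gather_fp8(t: torch.Tensor) -> list:
    """Gather `t` from every rank, sending float8 bytes instead of the working dtype."""
    world_size = dist.get_world_size()
    # per-tensor scale so values fit the float8_e4m3 range (max magnitude ~448)
    scale = (t.abs().max().clamp(min=1e-12) / 448.0).reshape(1)
    q = (t / scale).to(torch.float8_e4m3fn)

    # NCCL has no fp8 type, so exchange the raw bytes as uint8
    byte_bufs = [torch.empty_like(q.view(torch.uint8)) for _ in range(world_size)]
    dist.all_gather(byte_bufs, q.view(torch.uint8))
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale)

    # reinterpret the bytes as fp8 and dequantize back to the working dtype
    return [b.view(torch.float8_e4m3fn).to(t.dtype) * s for b, s in zip(byte_bufs, scales)]
```

Compared with a bf16 collective this roughly halves the traffic, at the cost of one quantization pass per tensor and some precision.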

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
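
For context on the hotfix above, a minimal illustration (hedged; where ColossalAI actually sets this is not shown here): CUDA_DEVICE_MAX_CONNECTIONS=1 is commonly required for communication/computation overlap schemes to keep kernel launch order predictable, and it must be set before CUDA is initialized in the process.

```python
import os

# must happen before the first CUDA/NCCL initialization in the process
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

import torch  # noqa: E402  (imported after the env var is in place on purpose)
```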

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
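
The all-gather overlap above is, conceptually, a prefetch: start the collective for the next set of parameters asynchronously and let the current computation hide its latency. A minimal sketch of the pattern (not the ZeRO implementation itself; names are illustrative):

```python
import torch
import torch.distributed as dist


def prefetch_allgather(shard: torch.Tensor):
    """Start an async all-gather for a parameter shard and return (buffer, handle)."""
    world_size = dist.get_world_size()
    full = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    handle = dist.all_gather_into_tensor(full, shard, async_op=True)
    return full, handle


# full_next, handle = prefetch_allgather(next_layer_shard)
# out = current_layer(x)   # runs while the all-gather is in flight
# handle.wait()            # next layer's parameters are ready just in time
```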

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up the construction of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method
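
The speed-up comes from two generic Python optimizations: skip deepcopy where a shallow reference suffices, and build data that is identical for every instance only once. A hedged sketch of the second point (illustrative names, not the real DimSpec code):

```python
class DimSpec:
    """Sharding spec of one tensor dimension (simplified illustration)."""

    _difference_dict = None  # shared by all instances, built lazily once

    def __init__(self, shard_list):
        self.shard_list = shard_list
        if DimSpec._difference_dict is None:
            DimSpec._difference_dict = self._build_difference_dict()

    @staticmethod
    def _build_difference_dict():
        # expensive-to-build table of conversion costs between shardings;
        # identical for every DimSpec, so one shared copy is enough
        return {("R", "S0"): 1, ("S0", "R"): 1, ("R", "R"): 0, ("S0", "S0"): 0}

    def __str__(self):
        return "R" if not self.shard_list else "S" + "".join(map(str, self.shard_list))

    def difference(self, other: "DimSpec") -> float:
        """Cost of converting this dim's sharding into `other`'s."""
        return self._difference_dict.get((str(self), str(other)), float("inf"))
```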

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
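
Lazy initialization lets the examples build models whose full weights would not fit on one device: parameters stay lazy until the booster shards and materializes them. A minimal usage sketch, assuming the public colossalai.lazy and booster APIs (plugin arguments and launch details are simplified):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch()

with LazyInitContext():
    model = GPT2LMHeadModel(GPT2Config())  # no real weight storage allocated yet

booster = Booster(plugin=GeminiPlugin())
model, *_ = booster.boost(model)  # weights are materialized and sharded here
```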

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
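
The fix follows the standard pattern for this class of bug: bind the variable to a default before any conditional assignment, so later reads cannot raise UnboundLocalError. A hypothetical illustration (names are made up, not the ColossalChat code):

```python
def resolve_conversation_template(name: str) -> str:
    default_conversation = None  # default value set up front
    if name == "llama":
        default_conversation = "llama-2-chat"
    elif name == "qwen":
        default_conversation = "qwen-chat"

    if default_conversation is None:
        raise ValueError(f"unknown conversation template: {name}")
    return default_conversation
```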

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
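
Ulysses-style sequence parallelism, referenced above, relies on an all-to-all that trades a sharded sequence dimension for sharded attention heads, so each rank attends over the full sequence with a subset of heads. A conceptual sketch only (not the ColossalAI kernel; the tensor layout is an assumption noted in the docstring):

```python
import torch
import torch.distributed as dist


def seq_shard_to_head_shard(x: torch.Tensor, group=None) -> torch.Tensor:
    """[batch, seq_len/P, heads, dim] per rank -> [batch, seq_len, heads/P, dim] per rank."""
    p = dist.get_world_size(group)
    b, s_local, h, d = x.shape
    assert h % p == 0, "number of heads must be divisible by the SP group size"

    # send a different 1/P slice of the head dimension to every rank ...
    inputs = [c.contiguous() for c in x.chunk(p, dim=2)]
    outputs = [torch.empty_like(c) for c in inputs]
    dist.all_to_all(outputs, inputs, group=group)

    # ... and stitch the received sequence chunks back into the full sequence
    return torch.cat(outputs, dim=1)
```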

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
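
A distributed eval loader simply gives every rank a disjoint shard of the eval set and reduces the metrics afterwards. A minimal sketch with plain PyTorch primitives (the implementation added in this PR may differ):

```python
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def build_eval_dataloader(dataset, batch_size: int) -> DataLoader:
    # each rank sees a different, non-overlapping shard of the eval data
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=False)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# after the local eval loop, average the per-rank metric across the group:
# dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM); eval_loss /= dist.get_world_size()
```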

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link (#5966)

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
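
The hotfix guards the optional dependency: import apex's fused RMSNorm when it is available and fall back to a plain PyTorch implementation otherwise, instead of failing at import time. A hedged sketch of the pattern (the fallback class here is illustrative):

```python
import torch

try:
    from apex.normalization import FusedRMSNorm as RMSNorm  # fused CUDA kernel
except ImportError:

    class RMSNorm(torch.nn.Module):  # pure-PyTorch fallback
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            variance = x.float().pow(2).mean(dim=-1, keepdim=True)
            normed = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)
            return self.weight.to(x.dtype) * normed
```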

* [Chat] fix readme (#5989)

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition (#6000)

* [plugin] add cast inputs option for zero (#6003)

* [pre-commit.ci] pre-commit autoupdate (#5995)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)

* [Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
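
For causal masks, a plain ring split of the sequence leaves later ranks with far more attendable positions than earlier ones. The "zigzag" layout pairs chunk i with chunk 2P-1-i so every rank gets a balanced share of the lower-triangular work. A small sketch of the partitioning only (the attention ring itself is omitted):

```python
import torch


def zigzag_split(seq: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return the two sequence chunks held by `rank` under the zigzag layout."""
    # seq_len must be divisible by 2 * world_size
    chunks = seq.chunk(2 * world_size, dim=0)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=0)


# world_size=2: chunks [0, 1, 2, 3] are placed as rank0 -> [0, 3], rank1 -> [1, 2]
```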

* [misc] update compatibility (#6008)

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale (#6017)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
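
Switching the plugins to the distributed logger is what makes "print on rank 0" possible: get_dist_logger comes from colossalai.logging (the same import this PR adds to booster.py in the diff below). A minimal sketch, assuming the logger's ranks filter:

```python
from colossalai.logging import get_dist_logger

logger = get_dist_logger()
# emitted once by rank 0 instead of once per process
logger.info("enabling all-gather overlap for hybrid parallel", ranks=[0])
```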

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Branch: pull/6024/head
Authored by Wang Binluo, committed by GitHub (3 months ago)
Commit: eea37da6fa

92 changed files (lines changed per file):

1. .compatibility (1)
2. .cuda_ext.json (4)
3. .github/workflows/build_on_pr.yml (2)
4. .github/workflows/build_on_schedule.yml (2)
5. .pre-commit-config.yaml (3)
6. applications/ColossalChat/.gitignore (1)
7. applications/ColossalChat/README.md (2)
8. applications/ColossalChat/coati/dataset/tokenization_utils.py (21)
9. applications/ColossalChat/coati/models/loss.py (16)
10. applications/ColossalChat/coati/models/utils.py (7)
11. applications/ColossalChat/coati/trainer/dpo.py (9)
12. applications/ColossalChat/coati/trainer/kto.py (37)
13. applications/ColossalChat/coati/trainer/orpo.py (12)
14. applications/ColossalChat/coati/trainer/ppo.py (12)
15. applications/ColossalChat/coati/trainer/sft.py (14)
16. applications/ColossalChat/examples/README.md (29)
17. applications/ColossalChat/examples/inference/inference.py (5)
18. applications/ColossalChat/examples/training_scripts/train_dpo.py (8)
19. applications/ColossalChat/examples/training_scripts/train_kto.py (2)
20. applications/ColossalChat/examples/training_scripts/train_orpo.py (2)
21. applications/ColossalChat/examples/training_scripts/train_ppo.py (4)
22. applications/ColossalChat/examples/training_scripts/train_sft.py (2)
23. applications/ColossalChat/requirements.txt (2)
24. applications/ColossalChat/tests/test_train.sh (36)
25. applications/README.md (2)
26. colossalai/booster/booster.py (18)
27. colossalai/booster/plugin/gemini_plugin.py (26)
28. colossalai/booster/plugin/hybrid_parallel_plugin.py (81)
29. colossalai/booster/plugin/low_level_zero_plugin.py (26)
30. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (25)
31. colossalai/booster/plugin/torch_ddp_plugin.py (2)
32. colossalai/booster/plugin/torch_fsdp_plugin.py (15)
33. colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (4)
34. colossalai/lazy/pretrained.py (4)
35. colossalai/legacy/moe/openmoe/model/openmoe_policy.py (2)
36. colossalai/legacy/nn/layer/parallel_1d/_operation.py (3)
37. colossalai/logging/logger.py (5)
38. colossalai/pipeline/schedule/interleaved_pp.py (1)
39. colossalai/pipeline/schedule/one_f_one_b.py (1)
40. colossalai/shardformer/layer/__init__.py (4)
41. colossalai/shardformer/layer/_operation.py (36)
42. colossalai/shardformer/layer/attn.py (916)
43. colossalai/shardformer/layer/linear.py (17)
44. colossalai/shardformer/layer/loss.py (167)
45. colossalai/shardformer/layer/normalization.py (25)
46. colossalai/shardformer/layer/qkv_fused_linear.py (2)
47. colossalai/shardformer/layer/utils.py (198)
48. colossalai/shardformer/modeling/chatglm2.py (20)
49. colossalai/shardformer/modeling/command.py (8)
50. colossalai/shardformer/modeling/deepseek.py (10)
51. colossalai/shardformer/modeling/llama.py (145)
52. colossalai/shardformer/modeling/mixtral.py (4)
53. colossalai/shardformer/policies/base_policy.py (1)
54. colossalai/shardformer/policies/command.py (31)
55. colossalai/shardformer/policies/llama.py (52)
56. colossalai/shardformer/policies/mistral.py (2)
57. colossalai/shardformer/policies/mixtral.py (15)
58. colossalai/shardformer/shard/shard_config.py (12)
59. colossalai/testing/utils.py (2)
60. colossalai/zero/gemini/gemini_optimizer.py (12)
61. docs/source/en/basics/launch_colossalai.md (7)
62. docs/source/zh-Hans/basics/launch_colossalai.md (6)
63. examples/language/llama/benchmark.py (33)
64. examples/language/opt/README.md (2)
65. examples/language/performance_evaluator.py (24)
66. examples/tutorial/opt/opt/README.md (2)
67. extensions/pybind/flash_attention/flash_attention_dao_cuda.py (8)
68. requirements/requirements-test.txt (2)
69. requirements/requirements.txt (2)
70. tests/kit/model_zoo/__init__.py (4)
71. tests/kit/model_zoo/transformers/command.py (12)
72. tests/kit/model_zoo/transformers/llama.py (41)
73. tests/kit/model_zoo/transformers/mistral.py (2)
74. tests/kit/model_zoo/transformers/mixtral.py (2)
75. tests/kit/model_zoo/transformers/qwen2.py (12)
76. tests/test_booster/test_plugin/test_3d_plugin.py (2)
77. tests/test_booster/test_plugin/test_low_level_zero_plugin.py (2)
78. tests/test_booster/test_plugin/test_torch_ddp_plugin.py (2)
79. tests/test_checkpoint_io/test_gemini_checkpoint_io.py (2)
80. tests/test_checkpoint_io/test_gemini_torch_compability.py (2)
81. tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py (2)
82. tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py (2)
83. tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py (2)
84. tests/test_lazy/test_models.py (14)
85. tests/test_lora/test_lora.py (2)
86. tests/test_pipeline/test_schedule/test_interleaved.py (17)
87. tests/test_pipeline/test_schedule/test_oneF_oneB.py (17)
88. tests/test_shardformer/test_flash_attention.py (3)
89. tests/test_shardformer/test_layer/test_ring_attn.py (186)
90. tests/test_shardformer/test_model/_utils.py (27)
91. tests/test_shardformer/test_model/test_shard_command.py (4)
92. tests/test_shardformer/test_model/test_shard_llama.py (113)

1
.compatibility

@@ -1,3 +1,4 @@
 2.1.0-12.1.0
 2.2.2-12.1.0
 2.3.0-12.1.0
+2.4.0-12.4.1

4
.cuda_ext.json

@@ -5,8 +5,8 @@
         "cuda_image": "hpcaitech/cuda-conda:12.1"
     },
     {
-        "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118",
-        "cuda_image": "hpcaitech/cuda-conda:11.8"
+        "torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
+        "cuda_image": "hpcaitech/cuda-conda:12.4"
     }
 ]
 }

2
.github/workflows/build_on_pr.yml

@@ -141,7 +141,7 @@ jobs:
       - name: Install Colossal-AI
         run: |
           BUILD_EXT=1 pip install -v -e .
-          pip install -r requirements/requirements-test.txt
+          pip install --no-cache-dir -r requirements/requirements-test.txt
       - name: Store Colossal-AI Cache
         run: |

2
.github/workflows/build_on_schedule.yml

@@ -57,7 +57,7 @@ jobs:
          [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
          BUILD_EXT=1 pip install -v -e .
          cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
-         pip install -r requirements/requirements-test.txt
+         pip install --no-cache-dir -r requirements/requirements-test.txt
       - name: Unit Testing
         if: steps.check-avai.outputs.avai == 'true'

3
.pre-commit-config.yaml

@@ -12,9 +12,10 @@ repos:
     hooks:
       - id: isort
         name: sort all imports (python)
+        args: ["--profile", "black"] # avoid conflict with black
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.4.2
+    rev: 24.8.0
     hooks:
       - id: black
         name: black formatter

1
applications/ColossalChat/.gitignore

@@ -151,6 +151,7 @@ examples/training_scripts/wandb
 examples/training_scripts/output
 examples/awesome-chatgpt-prompts/
+examples/inference/round.txt
 temp/
 # ColossalChat

2
applications/ColossalChat/README.md

@@ -121,7 +121,7 @@ cd $COLOSSAL_AI_ROOT
 BUILD_EXT=1 pip install .
 # Install ColossalChat
-cd $COLOSSAL_AI_ROOT/applications/Chat
+cd $COLOSSAL_AI_ROOT/applications/ColossalChat
 pip install .
 ```

21
applications/ColossalChat/coati/dataset/tokenization_utils.py

@ -49,6 +49,10 @@ def tokenize_sft(
messages = data_point["messages"] messages = data_point["messages"]
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
if messages[0]["from"] == "system":
template.system_message = str(messages[0]["content"])
messages.pop(0)
template.messages = [] template.messages = []
for idx, mess in enumerate(messages): for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]: if mess["from"] != template.roles[idx % 2]:
@ -148,11 +152,14 @@ def tokenize_prompt(
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
template.messages = [] template.messages = []
if messages[0]["from"] == "system":
template.system_message = str(messages[0]["content"])
messages.pop(0)
for idx, mess in enumerate(messages): for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]: if mess["from"] != template.roles[idx % 2]:
raise ValueError( raise ValueError(
f"Message should iterate between user and assistant and starts with a \ f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
line from the user. Got the following data:\n{messages}"
) )
template.append_message(mess["from"], mess["content"]) template.append_message(mess["from"], mess["content"])
@ -162,7 +169,7 @@ def tokenize_prompt(
template.messages = template.messages[:-1] template.messages = template.messages[:-1]
# Prepare data # Prepare data
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True) prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0] tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
if tokenizer.bos_token_id is not None: if tokenizer.bos_token_id is not None:
@ -225,6 +232,10 @@ def tokenize_rlhf(
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
template.clear() template.clear()
if context[0]["from"] == "system":
template.system_message = str(context[0]["content"])
context.pop(0)
for idx, mess in enumerate(context): for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]: if mess["from"] != template.roles[idx % 2]:
raise ValueError( raise ValueError(
@ -345,6 +356,10 @@ def tokenize_kto(
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
template.clear() template.clear()
if prompt[0]["from"] == "system":
template.system_message = str(prompt[0]["content"])
prompt.pop(0)
if prompt[0].get("from", None) != "user": if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user") raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant": if completion.get("from", None) != "assistant":

16
applications/ColossalChat/coati/models/loss.py

@@ -46,7 +46,10 @@ class PolicyLoss(nn.Module):
         action_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         skip = False
-        ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+        if action_mask is None:
+            ratio_ = (log_probs - old_log_probs).exp()
+        else:
+            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()

         # note that if dropout is disabled (recommanded), ratio will always be 1.
         if ratio_.mean() > self.skip_threshold:
@@ -56,7 +59,10 @@ class PolicyLoss(nn.Module):
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2)
-        loss = masked_mean(loss, action_mask)
+        if action_mask is not None:
+            loss = masked_mean(loss, action_mask)
+        else:
+            loss = loss.mean(dim=1)
         loss = loss.mean()
         return loss, skip, ratio_.max()
@@ -81,8 +87,10 @@ class ValueLoss(nn.Module):
         values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
         surr1 = (values_clipped - returns) ** 2
         surr2 = (values - returns) ** 2
-        loss = torch.max(surr1, surr2) / torch.sum(action_mask)
-        loss = torch.sum(loss * action_mask)
+        if action_mask is not None:
+            loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+        else:
+            loss = torch.mean(torch.max(surr1, surr2))

         return 0.5 * loss

7
applications/ColossalChat/coati/models/utils.py

@@ -138,6 +138,7 @@ def disable_dropout(model: torch.nn.Module):
     Returns:
         None
     """
-    for module in model.modules():
-        if isinstance(module, torch.nn.Dropout):
-            module.p = 0.0
+    if model is not None:
+        for module in model.modules():
+            if isinstance(module, torch.nn.Dropout):
+                module.p = 0.0

9
applications/ColossalChat/coati/trainer/dpo.py

@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer):
beta: float = 0.1, beta: float = 0.1,
gamma: float = 0.0, gamma: float = 0.0,
length_normalization: bool = False, length_normalization: bool = False,
apply_loss_mask: bool = True,
accumulation_steps: int = 1, accumulation_steps: int = 1,
start_epoch: int = 0, start_epoch: int = 0,
save_interval: int = 0, save_interval: int = 0,
@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta, gamma) self.actor_loss_fn = DpoLoss(beta, gamma)
self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval self.save_interval = save_interval
self.coordinator = coordinator self.coordinator = coordinator
self.save_dir = save_dir self.save_dir = save_dir
@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"], batch["reject_attention_mask"],
batch["reject_loss_mask"], batch["reject_loss_mask"],
) )
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0] batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model( actor_all_logits = self.model(
@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"], batch["reject_attention_mask"],
batch["reject_loss_mask"], batch["reject_loss_mask"],
) )
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0] batch_size = chosen_input_ids.size()[0]

37
applications/ColossalChat/coati/trainer/kto.py

@ -6,7 +6,7 @@ import os
from typing import Any, Optional from typing import Any, Optional
import torch import torch
import torch.distributed import torch.distributed as dist
from coati.models.loss import KTOLoss from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean from coati.trainer.utils import all_reduce_mean
@ -59,6 +59,7 @@ class KTOTrainer(SLTrainer):
beta: float = 0.1, beta: float = 0.1,
desirable_weight: float = 1.0, desirable_weight: float = 1.0,
undesirable_weight: float = 1.0, undesirable_weight: float = 1.0,
apply_loss_mask: bool = True,
accumulation_steps: int = 1, accumulation_steps: int = 1,
start_epoch: int = 0, start_epoch: int = 0,
save_interval: int = 0, save_interval: int = 0,
@ -70,6 +71,7 @@ class KTOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight) self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval self.save_interval = save_interval
self.coordinator = coordinator self.coordinator = coordinator
self.save_dir = save_dir self.save_dir = save_dir
@ -134,6 +136,10 @@ class KTOTrainer(SLTrainer):
batch["kl_attention_mask"], batch["kl_attention_mask"],
batch["kl_loss_mask"], batch["kl_loss_mask"],
) )
if not self.apply_loss_mask:
loss_mask = loss_mask.fill_(1.0)
kl_loss_mask = kl_loss_mask.fill_(1.0)
batch_size = input_ids.size()[0] batch_size = input_ids.size()[0]
# actor logits # actor logits
@ -182,8 +188,28 @@ class KTOTrainer(SLTrainer):
# sync # sync
loss_mean = all_reduce_mean(tensor=loss) loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean()) chosen_reward_mean = chosen_rewards.mean()
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean()) chosen_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(chosen_rewards_list, chosen_reward_mean)
rejected_reward_mean = rejected_rewards.mean()
rejected_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(rejected_rewards_list, rejected_reward_mean)
chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
chosen_rewards_mean = (
torch.stack(chosen_rewards_list).mean()
if len(chosen_rewards_list) > 0
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
)
rejected_rewards_mean = (
torch.stack(rejected_rewards_list).mean()
if len(rejected_rewards_list) > 0
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item()) self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
@ -256,6 +282,11 @@ class KTOTrainer(SLTrainer):
batch["kl_attention_mask"], batch["kl_attention_mask"],
batch["kl_loss_mask"], batch["kl_loss_mask"],
) )
if not self.apply_loss_mask:
loss_mask = loss_mask.fill_(1.0)
kl_loss_mask = kl_loss_mask.fill_(1.0)
batch_size = input_ids.size()[0] batch_size = input_ids.size()[0]
# actor logits # actor logits

12
applications/ColossalChat/coati/trainer/orpo.py

@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1, max_epochs: int = 1,
lam: float = 0.1, lam: float = 0.1,
apply_loss_mask: bool = True,
accumulation_steps: int = 1, accumulation_steps: int = 1,
start_epoch: int = 0, start_epoch: int = 0,
save_interval: int = 0, save_interval: int = 0,
@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer):
self.save_dir = save_dir self.save_dir = save_dir
self.num_train_step = 0 self.num_train_step = 0
self.lam = lam self.lam = lam
self.apply_loss_mask = apply_loss_mask
self.accumulation_steps = accumulation_steps self.accumulation_steps = accumulation_steps
self.device = get_current_device() self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter() self.accumulative_meter = AccumulativeMeanMeter()
@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"], batch["reject_attention_mask"],
batch["reject_loss_mask"], batch["reject_loss_mask"],
) )
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0] batch_size = chosen_input_ids.size()[0]
actor_out = self.model( actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"], batch["reject_attention_mask"],
batch["reject_loss_mask"], batch["reject_loss_mask"],
) )
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0] batch_size = chosen_input_ids.size()[0]
actor_out = self.model( actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), input_ids=torch.cat([chosen_input_ids, reject_input_ids]),

12
applications/ColossalChat/coati/trainer/ppo.py

@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer):
sample_buffer: bool = False, sample_buffer: bool = False,
dataloader_pin_memory: bool = True, dataloader_pin_memory: bool = True,
offload_inference_models: bool = True, offload_inference_models: bool = True,
apply_loss_mask: bool = True,
accumulation_steps: int = 1, accumulation_steps: int = 1,
save_interval: int = 0, save_interval: int = 0,
save_dir: str = None, save_dir: str = None,
@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer):
self.actor_optim = actor_optim self.actor_optim = actor_optim
self.critic_optim = critic_optim self.critic_optim = critic_optim
self.save_interval = save_interval self.save_interval = save_interval
self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor") self.actor_save_dir = os.path.join(save_dir, "actor")
self.critic_save_dir = os.path.join(save_dir, "critic") self.critic_save_dir = os.path.join(save_dir, "critic")
@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer):
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss, to_skip, max_ratio = self.actor_loss_fn( actor_loss, to_skip, max_ratio = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask if self.apply_loss_mask else None,
) )
actor_loss = (1 - self.ptx_coef) * actor_loss actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip: if not to_skip:
@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer):
input_ids=experience.sequences, attention_mask=experience.attention_mask input_ids=experience.sequences, attention_mask=experience.attention_mask
) # [batch size, prompt_length + response_length] ) # [batch size, prompt_length + response_length]
critic_loss = self.critic_loss_fn( critic_loss = self.critic_loss_fn(
values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask values[:, -num_actions:],
experience.values,
experience.advantages,
action_mask=experience.action_mask if self.apply_loss_mask else None,
) )
critic_loss = critic_loss * self.vf_coef critic_loss = critic_loss * self.vf_coef
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim) self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)

14
applications/ColossalChat/coati/trainer/sft.py

@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer):
lr_scheduler: _LRScheduler, lr_scheduler: _LRScheduler,
max_epochs: int = 2, max_epochs: int = 2,
accumulation_steps: int = 8, accumulation_steps: int = 8,
apply_loss_mask: bool = True,
start_epoch=0, start_epoch=0,
save_interval: int = None, save_interval: int = None,
save_dir: str = None, save_dir: str = None,
@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer):
self.coordinator = coordinator self.coordinator = coordinator
self.num_train_step = 0 self.num_train_step = 0
self.num_eval_step = 0 self.num_eval_step = 0
self.apply_loss_mask = apply_loss_mask
self.accumulative_meter = AccumulativeMeanMeter() self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit( def _before_fit(
@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer):
for i, batch in enumerate(self.train_dataloader): for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device()) batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0) batch_size = batch["input_ids"].size(0)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss = outputs.loss loss = outputs.loss
self.booster.backward(loss=loss, optimizer=self.optimizer) self.booster.backward(loss=loss, optimizer=self.optimizer)
@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer):
) )
for batch in self.eval_dataloader: for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device()) batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss_mean = all_reduce_mean(tensor=outputs.loss) loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update() step_bar.update()

29
applications/ColossalChat/examples/README.md

@@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
 - save_dir: path to store the model checkpoints.
 - max_length: input will be padded/truncated to max_length before feeding to the model.
 - max_epochs: number of epochs to train.
+- disable_loss_mask: whether to use the loss mask to mask the loss or not. For example, in SFT, if the loss mask is disabled, the model will compute the loss across all tokens in the sequence, if the loss mask is applied, only tokens correspond to the assistant responses will contribute to the final loss.
 - batch_size: training batch size.
 - mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
 - save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
@@ -461,26 +462,24 @@ Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of
 #### Step 1: Data Collection

-The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
+The first step in Stage 1 is to collect a dataset of human demonstrations of the following JSONL format.

 ```json
-[
-    {"messages":
-      [
-        {
-          "from": "user",
-          "content": "what are some pranks with a pen i can do?"
-        },
-        {
-          "from": "assistant",
-          "content": "Are you looking for practical joke ideas?"
-        },
-        ...
-      ]
-    },
-    ...
-]
+{"messages":
+  [
+    {
+      "from": "user",
+      "content": "what are some pranks with a pen i can do?"
+    },
+    {
+      "from": "assistant",
+      "content": "Are you looking for practical joke ideas?"
+    },
+    ...
+  ]
+},
+...
 ```

5
applications/ColossalChat/examples/inference/inference.py

@@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
         tuple: A tuple containing the loaded model and tokenizer.
     """
-    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
     tokenizer.pad_token = tokenizer.eos_token
     model.to(device)

@@ -151,7 +151,6 @@ def main(args):
         chat_io.prompt_for_output("assistant")
         prompt = conv.get_prompt(add_generation_prompt=True)
-        print(prompt + "<end_of_prompt>")
         input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
             torch.cuda.current_device()
         )

8
applications/ColossalChat/examples/training_scripts/train_dpo.py

@ -278,6 +278,10 @@ def train(args):
beta=args.beta, beta=args.beta,
gamma=args.gamma, gamma=args.gamma,
length_normalization=args.length_normalization, length_normalization=args.length_normalization,
<<<<<<< HEAD
=======
apply_loss_mask=not args.disable_loss_mask,
>>>>>>> main
) )
trainer.fit( trainer.fit(
@ -346,6 +350,10 @@ if __name__ == "__main__":
default=False, default=False,
help="Disable the reference model (enabled by default)", help="Disable the reference model (enabled by default)",
) )
<<<<<<< HEAD
=======
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
>>>>>>> main
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")

2
applications/ColossalChat/examples/training_scripts/train_kto.py

@ -297,6 +297,7 @@ def train(args):
beta=args.beta, beta=args.beta,
desirable_weight=args.desirable_weight, desirable_weight=args.desirable_weight,
undesirable_weight=args.undesirable_weight, undesirable_weight=args.undesirable_weight,
apply_loss_mask=not args.disable_loss_mask,
) )
trainer.fit( trainer.fit(
@ -341,6 +342,7 @@ if __name__ == "__main__":
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss") parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss") parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss") parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true") parser.add_argument("--zero_cpu_offload", default=False, action="store_true")

2
applications/ColossalChat/examples/training_scripts/train_orpo.py

@ -259,6 +259,7 @@ def train(args):
save_dir=args.save_dir, save_dir=args.save_dir,
coordinator=coordinator, coordinator=coordinator,
lam=args.lam, lam=args.lam,
apply_loss_mask=not args.disable_loss_mask,
) )
trainer.fit( trainer.fit(
@ -301,6 +302,7 @@ if __name__ == "__main__":
parser.add_argument("--pp", type=int, default=1) parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1) parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss") parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true") parser.add_argument("--zero_cpu_offload", default=False, action="store_true")

4
applications/ColossalChat/examples/training_scripts/train_ppo.py

@ -411,6 +411,7 @@ def train(args):
use_cache=True, use_cache=True,
do_sample=True, do_sample=True,
temperature=0.7, temperature=0.7,
apply_loss_mask=not args.disable_loss_mask,
accumulation_steps=args.accumulation_steps, accumulation_steps=args.accumulation_steps,
save_dir=args.save_path, save_dir=args.save_path,
save_interval=args.save_interval, save_interval=args.save_interval,
@ -498,9 +499,10 @@ if __name__ == "__main__":
parser.add_argument("--critic_lr", type=float, default=9e-6) parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.0) parser.add_argument("--ptx_coef", type=float, default=0.0)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048) parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_seq_len", type=int, default=256) parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--log_dir", default="logs", type=str) parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true")

2
applications/ColossalChat/examples/training_scripts/train_sft.py

@ -272,6 +272,7 @@ def train(args):
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps, accumulation_steps=args.accumulation_steps,
apply_loss_mask=not args.disable_loss_mask,
start_epoch=start_epoch, start_epoch=start_epoch,
save_interval=args.save_interval, save_interval=args.save_interval,
save_dir=args.save_path, save_dir=args.save_path,
@ -317,6 +318,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1) parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1) parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1) parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true") parser.add_argument("--zero_cpu_offload", default=False, action="store_true")

2
applications/ColossalChat/requirements.txt

@@ -2,7 +2,7 @@ transformers==4.39.3
 tqdm
 datasets==2.14.7
 loralib
-colossalai==0.4.0
+colossalai>=0.4.0
 torch>=2.1.0
 langchain
 tokenizers

36
applications/ColossalChat/tests/test_train.sh

@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
} }
set_n_least_used_CUDA_VISIBLE_DEVICES 4 set_n_least_used_CUDA_VISIBLE_DEVICES 2
set -xu set -xu
@ -119,11 +119,11 @@ for lora_rank in ${LORA_RANK[@]}; do
lora_config="" lora_config=""
fi fi
if [[ $plugin == "3d" ]]; then if [[ $plugin == "3d" ]]; then
tp='4' tp='2'
bs='8' bs='8'
fi fi
if [[ $plugin == "tp_zero2" ]]; then if [[ $plugin == "tp_zero2" ]]; then
tp='4' tp='2'
bs='8' bs='8'
zero_stage='2' zero_stage='2'
plugin='3d' plugin='3d'
@ -136,13 +136,13 @@ for lora_rank in ${LORA_RANK[@]}; do
fi fi
if [[ $plugin == "pp" ]]; then if [[ $plugin == "pp" ]]; then
bs='8' bs='8'
pp='4' pp='2'
plugin='3d' plugin='3d'
fi fi
if [[ $plugin == "sp_split_gather" ]]; then if [[ $plugin == "sp_split_gather" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism' enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='split_gather' sp_mode='split_gather'
tp='4' tp='2'
sp='1' sp='1'
bs='8' bs='8'
plugin='3d' plugin='3d'
@ -150,7 +150,7 @@ for lora_rank in ${LORA_RANK[@]}; do
if [[ $plugin == "sp_ring" ]]; then if [[ $plugin == "sp_ring" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism' enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='ring' sp_mode='ring'
tp='4' tp='2'
sp='1' sp='1'
bs='8' bs='8'
plugin='3d' plugin='3d'
@ -159,7 +159,7 @@ for lora_rank in ${LORA_RANK[@]}; do
enable_sequence_parallelism='--enable_sequence_parallelism' enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='all_to_all' sp_mode='all_to_all'
tp='1' tp='1'
sp='4' sp='2'
bs='8' bs='8'
plugin='3d' plugin='3d'
fi fi
@ -175,7 +175,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done done
-colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
+colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
 --pretrain $pretrain \
 --tokenizer_dir $tokenizer_dir \
 --dataset ${dataset[@]} \
@@ -242,7 +242,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 lora_config=""
 fi
 if [[ $plugin == "3d" ]]; then
-tp='4'
+tp='2'
 bs='8'
 fi
 grad_accu='2'
@@ -256,7 +256,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 for split in $(seq -f "%05g" 0 0); do
 dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
 done
-colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
+colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
 --pretrain $pretrain \
 --tokenizer_dir $tokenizer_dir \
 --dataset ${dataset[@]} \
@@ -325,7 +325,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 lora_config=""
 fi
 if [[ $plugin == "3d" ]]; then
-tp='4'
+tp='2'
 bs='16'
 ebs='32'
 fi
@@ -350,7 +350,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 for split in $(seq -f "%05g" 0 0); do
 ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
 done
-colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
+colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
 --pretrain $pretrain \
 --rm_pretrain $pretrain \
 --tokenizer_dir $tokenizer_dir \
@@ -417,7 +417,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 tp='1'
 bs='2'
 if [[ $plugin == "3d" ]]; then
-tp='4'
+tp='2'
 bs='8'
 fi
 if [[ $plugin == "zero2" ]]; then
@@ -442,7 +442,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 for split in $(seq -f "%05g" 0 0); do
 dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
 done
-colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
+colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
 --pretrain $pretrain \
 --tokenizer_dir $tokenizer_dir \
 --dataset ${dataset[@]} \
@@ -500,7 +500,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 tp='1'
 bs='2'
 if [[ $plugin == "3d" ]]; then
-tp='4'
+tp='2'
 bs='8'
 fi
 if [[ $plugin == "zero2" ]]; then
@@ -525,7 +525,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 for split in $(seq -f "%05g" 0 0); do
 dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
 done
-colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
+colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
 --pretrain $pretrain \
 --tokenizer_dir $tokenizer_dir \
 --dataset ${dataset[@]} \
@@ -583,7 +583,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 tp='1'
 bs='2'
 if [[ $plugin == "3d" ]]; then
-tp='4'
+tp='2'
 bs='8'
 fi
 if [[ $plugin == "zero2" ]]; then
@@ -608,7 +608,7 @@ for lora_rank in ${LORA_RANK[@]}; do
 for split in $(seq -f "%05g" 0 0); do
 dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
 done
-colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
+colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
 --pretrain $pretrain \
 --tokenizer_dir $tokenizer_dir \
 --dataset ${dataset[@]} \

applications/README.md (2 changes)

@@ -14,9 +14,9 @@ This directory contains the applications that are powered by Colossal-AI.
 The list of applications include:
 - [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
+- [X] [ColossalChat](./ColossalChat/): Replication of ChatGPT with RLHF.
 - [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervisied Fine-tuning of LLaMA2 / LLaMA3.
 - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
-- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
 - [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
 - [X] [ColossalQA](./ColossalQA/README.md): Document Retrieval Conversation System
 - [X] [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): Breaks the Length Limit of LLM Inference for Multi-Round Conversations

colossalai/booster/booster.py (18 changes)

@@ -1,4 +1,3 @@
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
@@ -8,6 +7,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from colossalai.logging import get_dist_logger
+
 SUPPORT_PEFT = False
 try:
 import peft
@@ -81,12 +82,15 @@ class Booster:
 plugin, Plugin
 ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
 self.plugin = plugin
+self.logger = get_dist_logger()
 # set accelerator
 if self.plugin and self.plugin.control_device():
 self.accelerator = None
 if device is not None:
-warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
+self.logger.warning(
+"The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
+)
 else:
 device = device or "cuda"
 self.accelerator = Accelerator(device)
@@ -94,7 +98,10 @@
 # set precision
 if self.plugin and self.plugin.control_precision():
 if mixed_precision is not None:
-warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
+self.logger.warning(
+"The plugin will control the precision," "so the mixed_precision argument will be ignored.",
+ranks=[0],
+)
 self.mixed_precision = None
 elif mixed_precision is None:
 self.mixed_precision = None
@@ -267,8 +274,9 @@
 ), "Please provide pretrained directory path if not passing in lora configuration."
 if quantize is True:
 if bnb_quantization_config is not None:
-warnings.warn(
-"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+self.logger.warning(
+"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
+ranks=[0],
 )
 else:
 bnb_quantization_config = BnbQuantizationConfig(
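A minimal sketch (not part of the diff) of the logging pattern the Booster changes above adopt: `get_dist_logger()` returns ColossalAI's distributed logger, and `ranks=[0]` restricts output to rank 0 so a warning is printed once per job rather than once per process. The message text is taken from the diff; everything else is illustrative.

```python
# Sketch only: the warnings.warn -> get_dist_logger().warning(..., ranks=[0]) pattern above.
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

# Printed by rank 0 only, matching the replacement calls in booster.py:
logger.warning(
    "The plugin will control the accelerator, so the device argument will be ignored.",
    ranks=[0],
)

# With the default ranks=None, every rank would emit its own copy of the record.
logger.info("booster initialized")
```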

colossalai/booster/plugin/gemini_plugin.py (26 changes)

@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 import random
 from pathlib import Path
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
 def __init__(self) -> None:
 super().__init__()
 self.coordinator = DistCoordinator()
+self.logger = get_dist_logger()
 def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
 """
@@ -118,7 +119,7 @@
 """
 assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
 if os.path.isfile(checkpoint_path):
-logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
 return
 Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -143,10 +144,11 @@
 index_file.append_meta_data("total_size", total_size)
 index_file.write_index_file(save_index_file)
 save_config_file(model.unwrap(), checkpoint_path)
-logging.info(
+self.logger.info(
 f"The model is split into checkpoint shards. "
 f"You can find where each parameters has been saved in the "
-f"index located at {save_index_file}."
+f"index located at {save_index_file}.",
+ranks=[0],
 )
 def load_sharded_model(
@@ -168,7 +170,7 @@
 assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
 if os.path.isfile(checkpoint):
-logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
 return
 Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -201,10 +203,11 @@
 if self.coordinator.is_master():
 index_file.append_meta_data("total_size", total_size)
 index_file.write_index_file(save_index_file)
-logging.info(
+self.logger.info(
 f"The optimizer is going to be split to checkpoint shards. "
 f"You can find where each parameters has been saved in the "
-f"index located at {save_index_file}."
+f"index located at {save_index_file}.",
+ranks=[0],
 )
 def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@@ -214,7 +217,7 @@
 """
 assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
 if not os.path.isfile(checkpoint_index_file):
-logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])
 assert isinstance(optimizer, GeminiOptimizer)
@@ -371,9 +374,12 @@ class GeminiPlugin(DPPluginBase):
 assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
 if get_accelerator().name == "npu":
 assert placement_policy == "static", "NPU only supports static placement policy"
+self.logger = get_dist_logger()
 if enable_async_reduce and not pin_memory:
-logging.warning(
-f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
+self.logger.warning(
+f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
+ranks=[0],
 )
 pin_memory = True
 self.gemini_config = dict(
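For context, a hedged sketch of the `enable_async_reduce` behaviour the new warning above guards: if async reduce is requested without pinned memory, the plugin warns on rank 0 and forces `pin_memory=True` itself. The argument names come from the diff; the values are illustrative only.

```python
# Sketch only: constructing GeminiPlugin so the pin_memory warning above is never hit.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    precision="bf16",
    placement_policy="static",
    enable_async_reduce=True,  # asynchronous gradient reduction
    pin_memory=True,           # required for best performance; forced to True otherwise
)
booster = Booster(plugin=plugin)
```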

colossalai/booster/plugin/hybrid_parallel_plugin.py (81 changes)

@@ -1,6 +1,5 @@
 import ctypes
 import random
-import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
@@ -27,13 +26,14 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
-from colossalai.shardformer.layer.utils import SeqParallelUtils
+from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -43,7 +43,7 @@ from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_hand
 from .pp_plugin_base import PipelinePluginBase
-SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]
 PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@@ -74,7 +74,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
 self.dp_group = dp_group
 self.tp_group = tp_group
 self.sp_group = sp_group
-self.use_dpp = use_ddp
+self.use_ddp = use_ddp
 self.require_grad_sync = True
 self.overlap_allgather = overlap_allgather
 self.use_fp8 = use_fp8
@@ -116,11 +116,10 @@
 super().__init__(module)
 self.op_hooks = []
-if overlap_allgather:
-self.op_hooks.append(ZeroOpHook())
 if use_fp8:
 self.op_hooks.append(FP8Hook())
-if overlap_allgather or use_fp8:
+if overlap_allgather:
+self.op_hook = ZeroOpHook()
 for p in module.parameters():
 if p.requires_grad and type(p) is not ColoParameter:
 p.__class__ = ColoParameter
@@ -146,8 +145,8 @@
 # Disable automatic gradient synchronization.
 self.require_grad_sync = False
 try:
-if self.use_dpp:
-# If using data parallel processing (use_dpp), disable synchronization too.
+if self.use_ddp:
+# If using data parallel processing (use_ddp), disable synchronization too.
 with self.module.no_sync():
 yield
 else:
@@ -195,7 +194,7 @@
 """
 if self.shard_config.enable_sequence_parallelism:
-if self.shard_config.sequence_parallelism_mode == "all_to_all":
+if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
 return
 if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
@@ -980,8 +979,11 @@ class HybridParallelPlugin(PipelinePluginBase):
 gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
 enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
 make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
-overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
 fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism
+overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
+inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
+It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
 """
 def __init__(
@@ -1031,8 +1033,10 @@
 overlap_allgather: bool = False,
 fp8_communication: bool = False,
 use_fp8: bool = False,
+inner_ring_size: int = None,
 ) -> None:
 super().__init__()
+self.logger = get_dist_logger()
 assert (
 dist.get_world_size() % (tp_size * pp_size) == 0
@@ -1050,14 +1054,17 @@
 tp_size > 1
 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
 if sp_size != 1:
-warnings.warn(
-f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+self.logger.warning(
+f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
+ranks=[0],
 )
 self.sp_size = 1
 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
-elif self.sequence_parallelism_mode in ["all_to_all"]:
+elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
 self.sp_size = 1 if sp_size is None else sp_size
 self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
+if self.sequence_parallelism_mode == "ring_attn":
+enable_flash_attention = True
 else:
 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
 assert (
@@ -1079,9 +1086,21 @@
 if dp_outside:
 self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
 self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+if sequence_parallelism_mode == "ring_attn":
+# Swap tp and sp since 2D Ring has better inter-node latency
+self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
+self.sp_axis = 2
+self.tp_axis = 3
+else:
+self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
 else:
 self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
+if sequence_parallelism_mode == "ring_attn":
+self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size)
+self.sp_axis = 2
+self.tp_axis = 3
+else:
+self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
 self.stage_manager = None
 self.schedule = None
@@ -1125,6 +1144,13 @@
 )
 else:
 raise NotImplementedError()
+if sequence_parallelism_mode == "ring_attn":
+if not parallel_output:
+self.logger.warning(
+"parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
+ranks=[0],
+)
+parallel_output = True
 self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
 self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1150,6 +1176,7 @@
 make_vocab_size_divisible_by=make_vocab_size_divisible_by,
 gradient_checkpoint_config=gradient_checkpoint_config,
 fp8_communication=fp8_communication,
+inner_ring_size=inner_ring_size,
 )
 self.amp_config = dict(
 initial_scale=initial_scale,
@@ -1229,20 +1256,23 @@
 optimizer = cast_to_distributed(optimizer)
 if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
-warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+self.logger.warning(
+"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+ranks=[0],
+)
 zero_config["partition_grad"] = False
 zero_stage = 0
 if not isinstance(model, ModelWrapper):
+# Shouldn't use pp (frequent grad accumulation) with torch ddp
 use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
-self.dp_size == 1
-and self.pp_size == 1
-and self.enable_sequence_parallelism
-and self.sequence_parallelism_mode == "all_to_all"
+self.dp_size == 1 and self.pp_size == 1
 )
 # sync gradients across DP * SP ranks
-if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+# Apply Hybrid ZeRO across DP * SP ranks
+if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
 dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+self.dp_size = get_world_size(dp_group)
 else:
 dp_group = self.dp_group
 model = HybridParallelModule(
@@ -1286,9 +1316,10 @@
 else:
 is_zero = self.dp_size > 1
 if self.dp_size == 1:
-warnings.warn(
+self.logger.warning(
 "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-"If you do not intend to use cpu_offload, please consider set zero_stage=0."
+"If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+ranks=[0],
 )
 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1331,7 +1362,7 @@
 assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
 if return_outputs:
-warnings.warn("return_outputs may lead to significant extra memory consumption.")
+self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])
 # Create a context for gradient synchronization based on the optimizer type.
 # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
@@ -1446,7 +1477,7 @@
 assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
 assert self.pp_size == 1 and self.tp_size == 1
 self.lora_enabled = True
-warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
 if bnb_quantization_config is not None:
 model = quantize_model(model, bnb_quantization_config)
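Pulling the plugin-level changes above together, here is a hedged sketch of how the new "ring_attn" sequence-parallel mode would be requested. The parameter names (`sp_size`, `sequence_parallelism_mode`, `inner_ring_size`, ...) appear in the diff; the sizes are illustrative, and per the code above the mode forces `enable_flash_attention=True` and `parallel_output=True`.

```python
# Sketch only: enabling the "ring_attn" mode newly added to SUPPORT_SP_MODE.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=4,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",
    inner_ring_size=None,  # leave unset; the plugin picks a value based on topology
)
booster = Booster(plugin=plugin)
```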

colossalai/booster/plugin/low_level_zero_plugin.py (26 changes)

@@ -1,7 +1,5 @@
 import enum
-import logging
 import os
-import warnings
 from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
@@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.quantization.fp8_hook import FP8Hook
@@ -146,7 +145,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 """
 assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
 if os.path.isfile(checkpoint):
-logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
 return
 Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -183,10 +182,11 @@
 index_file.append_meta_data("total_size", total_size)
 if self.coordinator.is_master():
 index_file.write_index_file(save_index_file)
-logging.info(
+self.logger.info(
 f"The optimizer is going to be split to checkpoint shards. "
 f"You can find where each parameters has been saved in the "
-f"index located at {save_index_file}."
+f"index located at {save_index_file}.",
+ranks=[0],
 )
 def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
@@ -273,7 +273,7 @@
 def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
 if os.path.isfile(checkpoint):
-logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
 return
 from peft import PeftModel
@@ -371,8 +371,8 @@ class LowLevelZeroPlugin(DPPluginBase):
 )
 self.lora_enabled = False
 self.verbose = verbose
+self.logger = get_dist_logger()
 self.use_fp8 = use_fp8
 # set class name with stage, for better error message
 setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -408,7 +408,7 @@
 assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
 self.lora_enabled = True
-warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
 if bnb_quantization_config is not None:
 model = quantize_model(model, bnb_quantization_config)
@@ -457,8 +457,9 @@
 origin_param = name2param[origin_key]
 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-warnings.warn(
-f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+self.logger.warning(
+f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
+ranks=[0],
 )
 elif (
 check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@@ -501,7 +502,10 @@
 optimizer = cast_to_distributed(optimizer)
 if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
-warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+self.logger.warning(
+"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+ranks=[0],
+)
 zero_optim_kwargs["partition_grad"] = False
 zero_stage = 0

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (25 changes)

@@ -1,4 +1,3 @@
-import warnings
 from collections import defaultdict
 from types import MethodType
 from typing import Callable, List, Optional, OrderedDict, Tuple
@@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
@@ -217,12 +217,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 fp8_communication: bool = False,
 use_fp8: bool = False,
 ) -> None:
+self.logger = get_dist_logger()
 if overlap_communication or zero_stage == 2:
 overlap_communication = False
 zero_stage = 1
-warnings.warn(
+self.logger.warning(
 f"overlap_communication and zero_stage are set to False and 1 because "
-f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
+f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.",
+ranks=[0],
 )
 assert (
@@ -240,8 +242,10 @@
 tp_size > 1
 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
 if sp_size != 1:
-warnings.warn(
-f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+self.logger.warning(
+f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode},"
+"will ignore the given sequence parallelism size.",
+ranks=[0],
 )
 self.sp_size = 1
 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -326,6 +330,7 @@
 else:
 self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 self.use_fp8 = use_fp8
+
 self.shard_config = ShardConfig(
 tensor_parallel_process_group=self.tp_group,
 sequence_parallel_process_group=self.sp_group,
@@ -403,8 +408,9 @@
 and self.sequence_parallelism_mode == "all_to_all"
 )
 if use_ddp:
-warnings.warn(
-f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
+self.logger.warning(
+f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
+ranks=[0],
 )
 self.ddp_config["find_unused_parameters"] = True
@@ -461,9 +467,10 @@
 )
 else:
 if self.dp_size <= 1:
-warnings.warn(
+self.logger.warning(
 "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-"If you do not intend to use cpu_offload, please consider set zero_stage=0."
+"If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+ranks=[0],
 )
 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
 optimizer = MoeHybridParallelZeroOptimizer(

colossalai/booster/plugin/torch_ddp_plugin.py (2 changes)

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
@@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
 def __init__(self) -> None:
 super().__init__()
 self.coordinator = DistCoordinator()
+self.logger = get_dist_logger()
 def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
 """

colossalai/booster/plugin/torch_fsdp_plugin.py (15 changes)

@@ -1,6 +1,4 @@
-import logging
 import os
-import warnings
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
@@ -30,6 +28,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from .dp_plugin_base import DPPluginBase
@@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
 def __init__(self) -> None:
 super().__init__()
 self.coordinator = DistCoordinator()
+self.logger = get_dist_logger()
 def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
 assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
@@ -88,7 +88,7 @@
 """
 assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
 if os.path.isfile(checkpoint_path):
-logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
 return
 Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -117,7 +117,7 @@
 index_file.append_meta_data("total_size", total_size)
 index_file.write_index_file(save_index_file)
 utils.save_config_file(model.unwrap(), checkpoint_path)
-logging.info(
+self.logger.info(
 f"The model is split into checkpoint shards. "
 f"You can find where each parameters has been saved in the "
 f"index located at {save_index_file}."
@@ -162,7 +162,7 @@
 assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
 if os.path.isfile(checkpoint):
-logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
 return
 Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -200,7 +200,7 @@
 index_file.append_meta_data("total_size", total_size)
 index_file.write_index_file(save_index_file)
-logging.info(
+self.logger.info(
 f"The optimizer is going to be split to checkpoint shards. "
 f"You can find where each parameters has been saved in the "
 f"index located at {save_index_file}."
@@ -313,6 +313,7 @@ class TorchFSDPPlugin(DPPluginBase):
 sync_module_states=sync_module_states,
 )
 self.fp8_communication = fp8_communication
+self.logger = get_dist_logger()
 else:
 raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -364,7 +365,7 @@
 if optimizer is not None:
 if len(optimizer.param_groups) > 1:
-warnings.warn(
+self.logger.warning(
 "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
 )
 optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (4 changes)

@@ -203,7 +203,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 return
 Path(checkpoint).mkdir(parents=True, exist_ok=True)
-
 # Devices along the same dp_group share the same copies of model.
 # So only let the device with dp_rank == 0 save the model.
 if self.dp_rank != 0:
@@ -643,14 +642,12 @@
 assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
 model._force_wait_all_gather()
 model = model.unwrap()
-
 if self.dp_rank != 0:
 return
 # The logic of collecting parameter shards along tp degree
 # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
 state_dict = model.state_dict()
-
 if self.pp_size == 1:
 # When pipeline is not used, let master rank directly save the collected state_dict.
 if self.tp_rank == 0:
@@ -660,7 +657,6 @@
 state_dict_list = [None for _ in range(self.pp_size)]
 dist.barrier(self.pp_group)
 dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
-
 # Only the master rank do the saving.
 if self.coordinator.is_master():
 complete_state_dict = dict()

colossalai/lazy/pretrained.py (4 changes)

@@ -62,7 +62,6 @@ def new_from_pretrained(
 config = kwargs.pop("config", None)
 cache_dir = kwargs.pop("cache_dir", None)
 force_download = kwargs.pop("force_download", False)
-resume_download = kwargs.pop("resume_download", False)
 proxies = kwargs.pop("proxies", None)
 local_files_only = kwargs.pop("local_files_only", False)
 use_auth_token = kwargs.pop("use_auth_token", None)
@@ -116,7 +115,6 @@
 cache_dir=cache_dir,
 return_unused_kwargs=True,
 force_download=force_download,
-resume_download=resume_download,
 proxies=proxies,
 local_files_only=local_files_only,
 use_auth_token=use_auth_token,
@@ -195,7 +193,6 @@
 "cache_dir": cache_dir,
 "force_download": force_download,
 "proxies": proxies,
-"resume_download": resume_download,
 "local_files_only": local_files_only,
 "use_auth_token": use_auth_token,
 "user_agent": user_agent,
@@ -312,7 +309,6 @@
 pretrained_model_name_or_path,
 cache_dir=cache_dir,
 force_download=force_download,
-resume_download=resume_download,
 proxies=proxies,
 local_files_only=local_files_only,
 use_auth_token=use_auth_token,

colossalai/legacy/moe/openmoe/model/openmoe_policy.py (2 changes)

@@ -171,7 +171,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
 policy = super().module_policy()
 if self.shard_config.enable_tensor_parallelism:
-# add a new item for casual lm
+# add a new item for causal lm
 # TODO: recursively assign ep group foe all modules
 new_item = {
 OpenMoeForCausalLM: ModulePolicyDescription(

colossalai/legacy/nn/layer/parallel_1d/_operation.py (3 changes)

@@ -81,6 +81,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
 handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
 # Delay the start of weight gradient computation shortly (3us) to have
 # all-reduce scheduled first and have GPU resources allocated
+# TODO: This seems to only work if you add torch.cuda.Event.wait()
+# _ = torch.zeros(1, device=grad_output.device)
 grad_weight = grad_output.t().matmul(total_input)
 grad_bias = grad_output.sum(dim=0) if use_bias else None
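The TODO added above concerns overlapping the weight-gradient GEMM with the asynchronous all-reduce of the input gradient. A hedged sketch of that overlap pattern follows; the tensors and group are placeholders, while `dist.all_reduce(..., async_op=True)` and `handle.wait()` are the standard torch.distributed calls used in the surrounding code.

```python
# Sketch only: comm/compute overlap in a column-parallel linear backward.
import torch.distributed as dist

def backward_with_overlap(grad_output, total_input, weight, group):
    grad_input = grad_output.matmul(weight)
    # Kick off the all-reduce asynchronously; it returns a Work handle.
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    # Independent work that can overlap with the in-flight all-reduce.
    grad_weight = grad_output.t().matmul(total_input)
    handle.wait()  # grad_input must be fully reduced before anyone consumes it
    return grad_input, grad_weight
```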

colossalai/logging/logger.py (5 changes)

@@ -64,7 +64,10 @@ class DistributedLogger:
 self._logger.propagate = False
 DistributedLogger.__instances[name] = self
-self.rank = dist.get_rank() if dist.is_initialized() else 0
+@property
+def rank(self):
+return dist.get_rank() if dist.is_initialized() else 0
 @staticmethod
 def __get_call_info():
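The change above turns `rank` from a value captured in `__init__` into a property, so a logger created before `torch.distributed` is initialized reports the correct rank once initialization has happened. A small hedged sketch of the difference:

```python
# Sketch only: rank is now re-evaluated on each access instead of cached at construction.
from colossalai.logging import get_dist_logger

logger = get_dist_logger()   # may be created before the process group exists
assert logger.rank == 0      # dist.is_initialized() is still False here

# ... after colossalai.launch() / dist.init_process_group() ...
# logger.rank now returns this process's real rank, not a stale 0.
```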

colossalai/pipeline/schedule/interleaved_pp.py (1 change)

@@ -306,7 +306,6 @@ class InterleavedSchedule(PipelineSchedule):
 # for the first stage, input_obj is None
 # for other stages, input_obj is the output of the previous stage containing hidden_states etc.
 # Only attention_mask from micro_batch is used
-
 with self.stage_manager.switch_model_chunk_id(model_chunk_id):
 if isinstance(model_chunk, ModuleList):
 output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)

colossalai/pipeline/schedule/one_f_one_b.py (1 change)

@@ -271,6 +271,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
 output_obj = model_forward(model, micro_batch, input_obj)
 if self.stage_manager.is_last_stage():
 loss = criterion(output_obj, micro_batch) / self.num_microbatches
+
 if accum_loss is not None:
 accum_loss.add_(loss.detach())
 if outputs is not None:

colossalai/shardformer/layer/__init__.py (4 changes)

@@ -1,5 +1,5 @@
 from ._operation import all_to_all_comm
-from .attn import AttnMaskType, ColoAttention
+from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
@@ -31,5 +31,7 @@ __all__ = [
 "VocabParallelLMHead1D",
 "AttnMaskType",
 "ColoAttention",
+"RingAttention",
+"get_pad_info",
 "all_to_all_comm",
 ]

36
colossalai/shardformer/layer/_operation.py

@ -2,6 +2,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from .utils import is_share_sp_tp
try: try:
import fused_mix_prec_layer_norm_cuda import fused_mix_prec_layer_norm_cuda
except: except:
@ -105,7 +107,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
elif ctx.async_grad_allreduce: elif ctx.async_grad_allreduce:
# Asynchronous all-reduce # Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
grad_weight = total_input.t().matmul(grad_output) grad_weight = total_input.t().matmul(grad_output)
@ -353,7 +355,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous() ).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None: if _grad_accum_fusion_available and weight.grad is not None:
@ -677,8 +679,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous() ).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output) grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None
@ -760,16 +762,20 @@ class _ReduceForward(torch.autograd.Function):
Args: Args:
input_: input matrix. input_: input matrix.
parallel_mode: parallel mode. process_group: communication group.
""" """
@staticmethod @staticmethod
def forward(ctx, input_, process_group, fp8_communication=False): def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
ctx.grad_scale = grad_scale
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3") return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
return grad_output, None, None if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return grad_output, None, None, None
class _ReduceBackward(torch.autograd.Function): class _ReduceBackward(torch.autograd.Function):
@ -1079,8 +1085,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, f
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def reduce_forward(input_, process_group, fp8_communication=False): def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
return _ReduceForward.apply(input_, process_group, fp8_communication) return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
def reduce_backward(input_, process_group, fp8_communication=False): def reduce_backward(input_, process_group, fp8_communication=False):
@ -1089,3 +1095,15 @@ def reduce_backward(input_, process_group, fp8_communication=False):
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False): def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
"""
Gather the output of the last layer for cross entropy computation
"""
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
)
return hidden_states
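As a rough aside on the grad_scale used above: the sketch below is plain Python, not ColossalAI code, and the numbers are made up. It only illustrates why pre-multiplying by sp_size before the plugin's later DP * SP gradient averaging leaves an effective average over the data-parallel group alone.
# minimal arithmetic sketch, assuming the plugin divides summed gradients by dp_size * sp_size
dp_size, sp_size = 2, 4
raw_grad = 8.0                                         # hypothetical summed gradient contribution
plain_avg = raw_grad / (dp_size * sp_size)             # averaged over the whole DP * SP group
rescaled = (raw_grad * sp_size) / (dp_size * sp_size)  # with grad_scale=sp_size applied first
assert rescaled == raw_grad / dp_size                  # i.e. averaged over DP ranks only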

916
colossalai/shardformer/layer/attn.py

File diff suppressed because it is too large

17
colossalai/shardformer/layer/linear.py

@ -202,19 +202,21 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel_mode is None:
+ if self.seq_parallel_mode == "split_gather":
- output_parallel = linear_with_async_comm(
- input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
- )
- elif self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
- output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
+ output_parallel = linear_with_async_comm(
+ input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
+ )
elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
)
+ else:
+ output_parallel = linear_with_async_comm(
+ input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
+ )
if self.gather_output:
# All-gather across the partitions.
@ -442,6 +444,9 @@ class Linear1D_Row(ParallelModule):
dim=self.seq_parallel_dim,
ring=True,
)
+ else:
+ output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+ output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:

167
colossalai/shardformer/layer/loss.py

@ -4,10 +4,15 @@ from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
+ from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig
+ from .utils import is_share_sp_tp
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
+ _IGNORE_IDX = -100
class DistCrossEntropy(Function):
r"""
@ -26,11 +31,12 @@ class DistCrossEntropy(Function):
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
+ mode="mean",
):
r"""
Calculate the cross entropy loss before gather; the original loss function is:
loss = -log(exp(x[class]) / sum(exp(x[i])))
- and can be rewrite as:
+ and can be rewritten as:
loss = log(sum(exp(x[i]))) - x[class]
To avoid `nan` in log(sum(exp(x[i]))), we subtract the max of x[i]
@ -44,12 +50,10 @@ class DistCrossEntropy(Function):
Returns:
:class:`torch.Tensor`: The cross entropy loss
"""
+ assert mode in ["mean", "sum"]
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
- dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
+ handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
- # minus the max to avoid the result of sum of exp is too large and the log is nan
- vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device
rank = dist.get_rank(group=process_group)
@ -70,24 +74,25 @@ class DistCrossEntropy(Function):
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
+ masked_target_1d = masked_target.view(-1).contiguous()
+ # subtract the max to avoid overflow in the sum of exp and nan in the log
+ handle.wait()
+ vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len]
self_vocab_size = vocab_logits.size()[-1]
logits_2d = vocab_logits.view(-1, self_vocab_size)
- masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero
- pred_logits_1d = logits_2d[
- torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d
- ]
- pred_logits_1d = pred_logits_1d.clone().contiguous()
+ idx = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
+ pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous()
pred_logits = pred_logits_1d.view_as(target)
pred_logits[mask] = 0.0
- # allreduce the get all x(i,y)
+ # all-reduce to get full x[i, y]
- dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
+ handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
@ -95,23 +100,29 @@ class DistCrossEntropy(Function):
# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
+ handle.wait()
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
- num_non_zero = torch.sum(loss != 0.0)
- ctx.inv_num_non_zero = 1.0 / num_non_zero
- loss = torch.sum(loss).div_(num_non_zero)
+ if mode == "mean":
+ num_non_zero = torch.sum(loss != 0.0)
+ ctx.inv_num_non_zero = 1.0 / num_non_zero
+ loss = torch.sum(loss).div_(num_non_zero)
+ else:
+ loss = torch.sum(loss)
# calculate the softmax
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
ctx.dtype = dtype
+ ctx.mode = mode
return loss
@staticmethod
def backward(ctx, grad_output):
# retrieve the saved tensors
- grad_output = grad_output * ctx.inv_num_non_zero
+ if ctx.mode == "mean":
+ grad_output = grad_output * ctx.inv_num_non_zero
exp_logits, mask, masked_target_1d = ctx.saved_tensors
# use exp logits as the input grad
@ -123,55 +134,113 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
- return grad_logits, None, None, None, None, None
+ return grad_logits, None, None, None, None, None, None
def cross_entropy_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
- ignore_index: int = -100,
+ ignore_index: int = _IGNORE_IDX,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
+ mode: str = "mean",
) -> torch.Tensor:
- return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+ return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
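For intuition, here is a minimal single-process sketch of the vocab-parallel cross entropy above: plain PyTorch, no process group, with the two-shard split and all shapes made up for illustration. The explicit max and sum over the simulated shards stand in for the MAX and SUM all-reduces in DistCrossEntropy.
import torch

torch.manual_seed(0)
N, V, shards = 6, 8, 2                        # tokens, full vocab size, simulated TP ranks
logits = torch.randn(N, V)
target = torch.randint(0, V, (N,))
piece = V // shards
shard_logits = [logits[:, r * piece:(r + 1) * piece] for r in range(shards)]

# "all-reduce MAX" over shards, then subtract for numerical stability
global_max = torch.stack([s.max(dim=-1).values for s in shard_logits]).max(dim=0).values
# "all-reduce SUM" of the exponentials over shards
sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in shard_logits)
# each shard contributes x[class] only if the target falls inside it, then "all-reduce SUM"
pred = torch.zeros(N)
for r, s in enumerate(shard_logits):
    local = (target >= r * piece) & (target < (r + 1) * piece)
    pred[local] = s[local, target[local] - r * piece] - global_max[local]

loss = (sum_exp.log() - pred).mean()          # log(sum(exp(x))) - x[class], averaged over tokens
reference = torch.nn.functional.cross_entropy(logits, target)
assert torch.allclose(loss, reference, atol=1e-5)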
def dist_cross_entropy(
- labels: torch.Tensor,
+ labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
- logits: torch.Tensor,
+ logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
out_features: int,
vocab_size: int,
dtype: torch.dtype,
+ seq_dim: int = 1,
) -> torch.Tensor:
"""
- Helper to compute cross entropy loss for most shardformer models,
- compatible with PP, TP and SP.
+ Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP.
"""
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_labels = shift_labels.view(-1)
- shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- # Cross entropy with all-reduce for TP
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=out_features,
- dtype=dtype,
- )
- else:
- # NOTE if use TP and not parallel_output, the output is gathered.
- # see VocabParallelLMHead1D
- shift_logits = shift_logits.view(-1, vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
- return loss
# Split labels if not gather output
sp_group = shard_config.sequence_parallel_process_group
sp_rank = dist.get_rank(sp_group)
sp_size = shard_config.sequence_parallel_size
sp_mode = shard_config.sequence_parallelism_mode
parallel_output = shard_config.parallel_output
is_tp = shard_config.enable_tensor_parallelism
is_packed = labels.dim() == 2
if is_packed:
bs, seq_len = labels.shape
else:
# padded sequence
seq_len = labels.shape[-1]
logits = logits.reshape(-1, *logits.shape[2:])
seq_dim = 0
# Shift labels to predict the next token, and remove the tail logit predicting <EOS>
is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode))
split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward
if sp_mode == "ring_attn":
# For Zigzag Ring Attention, labels should've been split and
# shifted by RingAttention.prepare_varlen_batch()
if sp_rank == 0:
logits = logits[..., :-1, :]
logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
elif is_sp:
# Shift only once: either before splitting or in the last rank without splitting
if split_labels_here or (sp_rank == sp_size - 1):
labels = labels[..., 1:]
if split_labels_here:
labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank]
if sp_rank == sp_size - 1:
logits = logits[..., :-1, :]
# Pad logits and labels to the same shape across all ranks for TP all_reduce
if is_tp and parallel_output:
# If is packed sequence (label dim is 1), then each seq already has the end label token padded.
# torch.cat is faster than F.pad...
pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:])
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device)
logits = torch.cat([logits, padding], dim=seq_dim)
pad_shape = (labels.shape[0], 1) if is_packed else (1,)
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device)
labels = torch.cat([labels, padding], dim=seq_dim)
else:
labels = labels[..., 1:]
logits = logits[..., :-1, :]
labels = labels.contiguous()
logits = logits.contiguous()
num_nonzero = (labels != _IGNORE_IDX).sum()
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")
labels = labels.view(-1)
if is_tp and parallel_output:
# Cross entropy with all-reduce for TP
new_vocab_size = logits.shape[-1]
logits = logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
logits,
labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features,
dtype=dtype,
mode="sum",
)
else:
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
logits = logits.view(-1, vocab_size)
loss = loss_fct(logits, labels)
# Reduce loss instead of gathering logits over seq dim for savings
if split_labels_here or sp_mode == "ring_attn":
# Get the global non-zero count
loss = torch.stack((loss, num_nonzero))
# Rescale to offset the grad / (DP * SP) in HybridParallelPlugin
loss = reduce_forward(loss, sp_group, grad_scale=sp_size)
loss, num_nonzero = loss[0], loss[1].detach()
loss = (loss / num_nonzero).squeeze()
return loss
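A rough sanity check of the reduction strategy used above, with plain tensors and no process groups: the two chunks below simulate sp_size ranks and the numbers are made up. Each rank contributes a sum loss plus its non-ignored token count, the pairs are summed (the role of reduce_forward with SUM), and the final division recovers the global mean loss.
import torch

losses = [torch.tensor([2.0, 0.0, 1.5]), torch.tensor([0.5, 3.0, 0.0])]  # per-"rank" token losses, 0.0 = ignored
local = [torch.stack((chunk.sum(), (chunk != 0).sum().float())) for chunk in losses]
total = torch.stack(local).sum(dim=0)        # what the cross-rank SUM reduction would produce
global_mean = total[0] / total[1]
assert torch.isclose(global_mean, torch.tensor(7.0 / 4))  # 7.0 total loss over 4 real tokens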

25
colossalai/shardformer/layer/normalization.py

@ -42,7 +42,7 @@ try:
return output
except ImportError:
- warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
+ warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
@ -270,12 +270,6 @@ class FusedRMSNorm(BaseLayerNorm):
Returns:
nn.Module: FusedRMSNorm module.
"""
- try:
- pass
- except ImportError:
- raise ImportError(
- "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
- )
LazyInitContext.materialize(module)
@ -284,11 +278,18 @@ class FusedRMSNorm(BaseLayerNorm):
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True)
- rmsnorm = FusedRMSNormWithHook(
- normalized_shape=normalized_shape,
- eps=eps,
- elementwise_affine=elementwise_affine,
- )
+ try:
+ rmsnorm = FusedRMSNormWithHook(
+ normalized_shape=normalized_shape,
+ eps=eps,
+ elementwise_affine=elementwise_affine,
+ )
+ except ImportError:
+ warnings.warn(
+ "Module replacement failed.\
+ Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
+ )
+ return module
rmsnorm.weight = module.weight

2
colossalai/shardformer/layer/qkv_fused_linear.py

@ -555,7 +555,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
else:
if self.seq_parallel_mode is None:
output_parallel = torch.matmul(input_, self.weight)
- output = reduce_forward(output_parallel, self.process_group, self.fp8_communication)
+ output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(

198
colossalai/shardformer/layer/utils.py

@ -1,5 +1,5 @@
from contextlib import contextmanager
- from typing import List
+ from typing import List, Optional, Union
import torch
import torch.distributed as dist
@ -289,3 +289,199 @@ def create_randomizer_with_offset(
Randomizer.increment_index()
return Randomizer(seed=base_seed)
def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
in the causal setting will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
Args:
batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
sp_group (ProcessGroup): The process group for sequence parallelism.
seq_dim (int): The sequence dimension to split.
is_label (bool): If True, mask and shift the tensor for next token prediction.
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if isinstance(batch, torch.Tensor):
batch = [batch]
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
if sp_size > 1:
for idx, tensor in enumerate(batch):
assert (
tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
), f"Bro, the seq length {tensor.shape[seq_dim]} for tensor {idx} can't be split by {sp_size * 2}!"
if is_label:
assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)
tensor = tensor.view(
*tensor.shape[:seq_dim],
2 * sp_size,
tensor.shape[seq_dim] // (2 * sp_size),
*tensor.shape[seq_dim + 1 :],
)
indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
tensor = tensor.index_select(seq_dim, indices).contiguous()
# (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])
if len(batch) == 1:
return batch[0]
return batch
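For reference, a minimal single-process sketch of the zigzag layout described in the docstring above, using sp_size = 4 and seq_len = 8; sp_rank is just a loop variable here and no process group is involved.
import torch

sp_size, seq_len = 4, 8
seq = torch.arange(seq_len)                                # tokens s0 .. s7
chunks = seq.view(2 * sp_size, seq_len // (2 * sp_size))   # 8 chunks, one token each
for sp_rank in range(sp_size):
    idx = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank])
    local = chunks.index_select(0, idx).flatten()
    print(f"rank {sp_rank}: {local.tolist()}")             # -> [0, 7], [1, 6], [2, 5], [3, 4]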
def split_varlen_zigzag(
batch: Union[List[torch.Tensor], torch.Tensor],
cu_seqlens: torch.Tensor,
sp_group: ProcessGroup,
max_seqlen: int = 0,
is_2d: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Split each sequence in a batch of packed sequences in a zigzag fashion.
For each tensor in batch, return packed sequences if is_2d is False;
else return a padded batch of sequences.
Args:
batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
Returns:
batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
or (B, max_seqlen // sp_size, ...) if is_2d
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if is_2d:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
if isinstance(batch, torch.Tensor):
batch = [batch]
for i, packed_seq in enumerate(batch):
device = packed_seq.device
dtype = packed_seq.dtype
if is_2d:
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
local_seq = torch.zeros(shape, dtype=dtype, device=device)
else:
total_seqlen = cu_seqlens[-1]
assert (
total_seqlen % (2 * sp_size) == 0
), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
local_seq = []
for j in range(len(cu_seqlens) - 1):
start, end = cu_seqlens[j], cu_seqlens[j + 1]
seqlen = end - start
assert (
seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
if is_2d:
seq = packed_seq[j][:seqlen]
if is_label:
# Shift one position to the right for next token prediction
seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
seq = seq.chunk(2 * sp_size, dim=0)
half = seqlen // sp_size // 2
local_seq[j][:half] = seq[sp_rank]
local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
else:
seq = packed_seq[start:end]
if is_label:
seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
if is_2d:
batch[i] = local_seq.contiguous()
else:
batch[i] = torch.cat(local_seq, dim=0)
if len(batch) == 1:
batch = batch[0]
return batch
def is_share_sp_tp(sp_mode: str):
"""sp_mode "ring" and "split_gather" use the TP group as SP group
to split both the vocab and sequence, so we must gather the sequence
to correctly get logits at each position.
"""
return sp_mode in ["ring", "split_gather"]
class RingComm:
def __init__(self, process_group: dist.ProcessGroup):
self._process_group = process_group
self._ops = []
self.rank = dist.get_rank(self._process_group)
self.world_size = dist.get_world_size(self._process_group)
self._reqs = []
self.send_rank = (self.rank + 1) % self.world_size
self.recv_rank = (self.rank - 1) % self.world_size
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
def send_recv(
self,
send_tensor: torch.Tensor,
recv_tensor: Optional[torch.Tensor] = None,
commit: bool = True,
) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(send_tensor)
else:
res = recv_tensor
# looks like batch_isend_irecv doesn't deadlock even
# when we don't swap send recv ops based on rank
send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.extend([send_op, recv_op])
if commit:
self._reqs = dist.batch_isend_irecv(self._ops)
return res
def commit(self):
assert len(self._ops) > 0, "No ops to commit"
self._reqs = dist.batch_isend_irecv(self._ops)
def wait(self):
assert len(self._reqs) > 0, "No requests to wait for"
for req in self._reqs:
req.wait()
self._reqs = []
self._ops = []
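As a quick illustration of the neighbour bookkeeping RingComm performs in __init__ (before mapping local ranks to global ranks), plain Python with no torch.distributed required:
world_size = 4
for rank in range(world_size):
    send_rank = (rank + 1) % world_size   # each rank sends to its right neighbour
    recv_rank = (rank - 1) % world_size   # and receives from its left neighbour
    print(f"rank {rank}: send -> {send_rank}, recv <- {recv_rank}")
# rank 0 sends to 1 and receives from 3, so each KV block circulates the full ring in world_size - 1 steps.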
@torch.jit.script
def get_half_index(cu_seqlens, *, front: bool):
index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
for i in range(len(cu_seqlens) - 1):
start, end = cu_seqlens[i], cu_seqlens[i + 1]
if front:
end = (start + end) // 2
else:
start = (start + end) // 2
index[start:end] = True
return index
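To make the indexing concrete, a small eager-mode walk-through of the same logic with a made-up cu_seqlens for two packed sequences of lengths 4 and 6 (the front=True case):
import torch

# mirrors get_half_index: mark the first half of every packed sequence
cu_seqlens = torch.tensor([0, 4, 10])
front = torch.zeros(int(cu_seqlens[-1]), dtype=torch.bool)
for i in range(len(cu_seqlens) - 1):
    start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
    front[start:(start + end) // 2] = True
print(front.int().tolist())  # [1, 1, 0, 0, 1, 1, 1, 0, 0, 0] -> the "front" half of each sequence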

20
colossalai/shardformer/modeling/chatglm2.py

@ -216,6 +216,13 @@ class ChatGLMPipelineForwards:
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
@ -257,6 +264,13 @@ class ChatGLMPipelineForwards:
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
@ -405,6 +419,12 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
if sp_mode in ["all_to_all"] and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
)
use_cache = False
if sp_mode in ["all_to_all"] and self.training:
if use_cache:
logger.warning_once(

8
colossalai/shardformer/modeling/command.py

@ -26,6 +26,8 @@ from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy
+ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
class CommandPipelineForwards:
"""
@ -353,7 +355,7 @@ class CommandPipelineForwards:
return {"hidden_states": hidden_states}
- def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self,
hidden_states: torch.Tensor,
@ -366,7 +368,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None:
- assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
@ -465,7 +467,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
return forward
- def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(

10
colossalai/shardformer/modeling/deepseek.py

@ -145,7 +145,11 @@ class EPDeepseekMoE(nn.Module):
output_split_sizes = torch.zeros_like(input_split_sizes)
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
- dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+ dist.all_to_all_single(
+ output_split_sizes,
+ input_split_sizes,
+ group=self.ep_group,
+ )
with torch.no_grad():
activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
@ -695,6 +699,10 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
self._use_flash_attention_2 = shard_config.enable_flash_attention
self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None

145
colossalai/shardformer/modeling/llama.py

@ -1,8 +1,9 @@
import math
import warnings
- from typing import List, Optional, Tuple, Union
+ from typing import Dict, List, Optional, Tuple, Union
import torch
+ import torch.distributed
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@ -24,14 +25,14 @@ from transformers.models.llama.modeling_llama import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
- from colossalai.shardformer.layer._operation import (
- all_to_all_comm,
- gather_forward_split_backward,
- split_forward_gather_backward,
- )
+ from colossalai.shardformer.layer import AttnMaskType
+ from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
+ from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
- from ..layer import ColoAttention, dist_cross_entropy
+ from ..layer import ColoAttention, RingAttention, dist_cross_entropy
+ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
class LlamaPipelineForwards:
@ -57,6 +58,10 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
+ # Split output only when computing cross entropy using llama_for_causal_lm_forward
+ # or get_lm_forward_with_dist_cross_entropy
+ # Default to True to avoid bug when calling classification forward from huggingface
+ force_sp_output_gather: bool = True,
):
logger = logging.get_logger(__name__)
@ -97,7 +102,7 @@ class LlamaPipelineForwards:
sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
- # For correct positions ids. The states will be gather along the seq dim in the attention layer later.
+ # For generating full position ids, as the states will be gathered along the seq dim in the attention layer later.
seq_length *= sp_size
past_seen_tokens = 0
@ -127,22 +132,36 @@ class LlamaPipelineForwards:
position_ids = cache_position.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
- if shard_config.enable_flash_attention:
+ if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
+ _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
+ elif shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
- attention_mask = ColoAttention.prepare_attn_kwargs(
+ attn_kwargs = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
+ invert=(sp_mode != "ring_attn"),
)
else:
- attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+ attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
# Support SP + PP
+ # TODO: support padded causal cu_seqlens across stages
if stage_manager.is_first_stage():
- if sp_mode in ["ring", "split_gather"]:
+ # Ring Attention zigzag batch processing
+ if sp_mode == "ring_attn":
+ assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
+ if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
+ hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
+ attention_mask, sp_group, hidden_states, position_ids
+ )
+ else:
+ hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
+ elif is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
@ -181,12 +200,11 @@ class LlamaPipelineForwards:
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- attention_mask,
+ attn_kwargs,
position_ids,
past_key_values,
output_attentions,
@ -196,14 +214,13 @@ class LlamaPipelineForwards:
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=attention_mask,
+ attention_mask=attn_kwargs,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
@ -213,13 +230,9 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
- if sp_mode == "ring" or sp_mode == "split_gather":
- hidden_states = gather_forward_split_backward(
- hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
- )
- elif sp_mode == "all_to_all":
- hidden_states = gather_forward_split_backward(
- hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
- )
+ if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+ hidden_states = gather_sp_output(
+ hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
+ )
# add hidden states from the last decoder layer
@ -306,6 +319,15 @@ class LlamaPipelineForwards:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
# Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
else:
# [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
self.model,
@ -323,6 +345,7 @@ class LlamaPipelineForwards:
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
+ force_sp_output_gather=False,
)
past_key_values = None
@ -469,7 +492,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
def forward(
self,
hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
@ -478,7 +501,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None:
- assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
@ -489,7 +512,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
- if sp_mode in ["split_gather", "ring"]:
+ if is_share_sp_tp(sp_mode):
q_len *= sp_size
if self.config.pretraining_tp > 1:
@ -534,6 +557,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@ -545,12 +569,21 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- if shard_config.enable_flash_attention:
+ if sp_mode == "ring_attn":
+ attn_output = RingAttention.attention(
+ query_states,
+ key_states,
+ value_states,
+ sp_group,
+ **attention_mask,
+ inner_ring_size=shard_config.inner_ring_size,
+ )
+ elif shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@ -613,6 +646,10 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
# Split output only when computing cross entropy using llama_for_causal_lm_forward
# or get_lm_forward_with_dist_cross_entropy
# Default to True to avoid bug when calling classification forward from huggingface
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -639,32 +676,45 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
past_seen_tokens = 0
seq_len = inputs_embeds.shape[1]
+ batch_size = inputs_embeds.shape[0]
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
+ # in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention:
- mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len)
+ mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
- attention_mask = ColoAttention.prepare_attn_kwargs(
+ attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
inputs_embeds.device,
q_padding_mask=attention_mask,
is_causal=True,
+ invert=(sp_mode != "ring_attn"),
)
else:
- attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+ attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, inputs_embeds, position_ids
)
else:
inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
- if sp_mode in ["ring", "split_gather"]:
+ elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
@ -686,7 +736,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- attention_mask,
+ attn_kwargs,
position_ids,
past_key_values,
output_attentions,
@ -697,7 +747,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=attention_mask,
+ attention_mask=attn_kwargs,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@ -714,14 +764,10 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
- if sp_mode == "ring" or sp_mode == "split_gather":
- hidden_states = gather_forward_split_backward(
- hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
- )
- elif sp_mode == "all_to_all":
- hidden_states = gather_forward_split_backward(
- hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
- )
+ # Cases that don't support parallelizing cross entropy computation along sequence
+ if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
+ hidden_states = gather_sp_output(
+ hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
+ )
# add hidden states from the last decoder layer
@ -795,6 +841,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
# Special processing: Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
else:
# [B, max_seq_len // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
@ -807,6 +862,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
+ force_sp_output_gather=False,
)
hidden_states = outputs[0]
@ -817,7 +873,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)

4
colossalai/shardformer/modeling/mixtral.py

@ -696,7 +696,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
- attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
+ attn_output = all_to_all_comm(
+ attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+ ) # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

1
colossalai/shardformer/policies/base_policy.py

@ -75,6 +75,7 @@ class Policy(ABC):
def __init__(self) -> None:
self.shard_config: Optional[ShardConfig] = None
self.model: Optional[Module] = None
+ self.is_causal = None # Whether we're doing causal lm, i.e. using cross entropy
def set_model(self, model: nn.Module) -> None:
r"""

31
colossalai/shardformer/policies/command.py

@ -69,13 +69,18 @@ class CommandPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "ring_attn" and not self.is_causal:
+ raise ValueError("Ring attention is only meant for causal language modeling.")
+ tp_size = self.shard_config.tensor_parallel_size or None
+ num_q_heads = self.model.config.num_attention_heads
+ num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all":
- decoder_attribute_replacement = {
- "num_heads": self.model.config.num_attention_heads // sp_size,
- }
- if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+ num_q_heads //= sp_size
+ decoder_attribute_replacement = {"num_heads": num_q_heads}
+ if num_kv_heads:
+ num_kv_heads //= sp_size
+ decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@ -104,21 +109,18 @@ class CommandPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
assert (
- self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ num_q_heads % tp_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"):
assert (
- self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
- and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
decoder_attribute_replacement = {
- "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
- "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": num_q_heads // tp_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
- self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
- )
+ decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size
policy[CohereDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@ -297,10 +299,11 @@ class CommandForCausalLMPolicy(CommandPolicy):
def module_policy(self):
from transformers import CohereForCausalLM
+ self.is_causal = True
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
new_item = {
CohereForCausalLM: ModulePolicyDescription(
sub_module_replacement=[

52
colossalai/shardformer/policies/llama.py

@ -69,13 +69,20 @@ class LlamaPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "ring_attn" and not self.is_causal:
+ raise ValueError("Ring attention is only meant for causal language modeling.")
+ tp_size = self.shard_config.tensor_parallel_size
+ # Modified by SP and TP
+ num_q_heads = self.model.config.num_attention_heads
+ num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all":
- decoder_attribute_replacement = {
- "num_heads": self.model.config.num_attention_heads // sp_size,
- }
- if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+ num_q_heads //= sp_size
+ decoder_attribute_replacement = {"num_heads": num_q_heads}
+ if num_kv_heads:
+ num_kv_heads //= sp_size
+ decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@ -104,21 +111,20 @@ class LlamaPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
assert (
- self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ num_q_heads % tp_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"):
assert (
- self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
- and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
+ num_q_heads //= tp_size
decoder_attribute_replacement = {
- "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
- "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": num_q_heads,
}
if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
- self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
- )
+ num_kv_heads //= tp_size
+ decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -302,10 +308,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
     def module_policy(self):
         from transformers import LlamaForCausalLM

+        self.is_causal = True
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
-            # add a new item for casual lm
+            # add a new item for causal lm
             new_item = {
                 LlamaForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
@@ -321,10 +328,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                     ],
                 )
             }
-            if self.shard_config.parallel_output:
-                new_item[LlamaForCausalLM].method_replacement = {
-                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-                }
         else:
             new_item = {
                 LlamaForCausalLM: ModulePolicyDescription(
@@ -344,7 +347,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
             self.set_pipeline_forward(
                 model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
             )
+        elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
+            # Compute loss distributedly along the sequence dimension
+            new_item[LlamaForCausalLM].method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
         return policy

     def get_held_layers(self) -> List[Module]:
@@ -384,7 +391,12 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                 LlamaForSequenceClassification: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                            suffix="score",
+                            target_module=Linear1D_Col,
+                            kwargs=dict(
+                                gather_output=True,
+                                fp8_communication=self.shard_config.fp8_communication,
+                            ),
                         )
                     ]
                 )
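The new `elif` branch in `LlamaForCausalLMPolicy` routes TP/SP runs through `get_lm_forward_with_dist_cross_entropy`, i.e. each rank computes the loss on its own sequence shard and the result is reduced across the group. The sketch below only illustrates that idea; it is not the Colossal-AI implementation, and the helper name and group handling are assumptions (it also omits the usual causal-LM label shift). It assumes `torch.distributed` has been initialized.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def sharded_causal_lm_loss(local_logits, local_labels, sp_group):
    # local_logits: [B, S_local, V]; local_labels: [B, S_local] with -100 marking ignored tokens
    loss_sum = F.cross_entropy(
        local_logits.reshape(-1, local_logits.size(-1)),
        local_labels.reshape(-1),
        ignore_index=-100,
        reduction="sum",
    ).float()
    num_tokens = (local_labels != -100).sum().float()
    # Reduce both the summed loss and the valid-token count over the sequence-parallel group,
    # then normalize so every rank sees the same mean loss.
    stats = torch.stack([loss_sum, num_tokens])
    dist.all_reduce(stats, group=sp_group)
    return stats[0] / stats[1]
```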

2
colossalai/shardformer/policies/mistral.py

@@ -299,7 +299,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
-            # add a new item for casual lm
+            # add a new item for causal lm
             new_item = {
                 MistralForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[

15
colossalai/shardformer/policies/mixtral.py

@@ -144,10 +144,14 @@ class MixtralPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={
-                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
-                        "fp8_communication": self.shard_config.fp8_communication,
-                    },
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=MixtralModel,
@@ -164,7 +168,6 @@ class MixtralPolicy(Policy):
                         "ep_group": self.shard_config.ep_group,
                         "tp_group": self.shard_config.tensor_parallel_process_group,
                         "moe_dp_group": self.shard_config.moe_dp_group,
-                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 )
             ],
@@ -285,7 +288,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         policy = super().module_policy()
         # TODO: assign pg mesh from plugin to all modules
         if self.shard_config.enable_tensor_parallelism:
-            # add a new item for casual lm
+            # add a new item for causal lm
             new_item = {
                 MixtralForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[

12
colossalai/shardformer/shard/shard_config.py

@@ -10,7 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from .grad_ckpt_config import GradientCheckpointConfig

 __all__ = ["ShardConfig"]
-SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]

 @dataclass
@@ -30,6 +30,8 @@ class ShardConfig:
         gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
+        parallel_output (bool): For TP: whether to parallelize the cross entropy computation along the feature dim.
+            For SP: set to True to NOT gather the output along the seq dim.
     """

     tensor_parallel_process_group: Optional[ProcessGroup] = None
@@ -48,6 +50,8 @@ class ShardConfig:
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)

+    # For ring attention
+    inner_ring_size: Optional[int] = None
     # for moe related
     moe_dp_group: Optional[ProcessGroup] = None
     ep_group: Optional[ProcessGroup] = None
@@ -81,9 +85,9 @@ class ShardConfig:
                 self.enable_tensor_parallelism
             ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
         elif self.sequence_parallelism_mode in ["all_to_all"]:
-            assert (
-                not self.enable_tensor_parallelism
-            ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
+            # assert (
+            #     not self.enable_tensor_parallelism
+            # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
             if self.enable_sequence_overlap:
                 self.enable_sequence_overlap = False
                 warnings.warn(
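The config now accepts "ring_attn" as a sequence-parallel mode and exposes `inner_ring_size` for it. A minimal validation sketch of how these two pieces fit together (the helper below is illustrative only, not part of Colossal-AI; only `SUPPORT_SP_MODE` and the field names come from the diff above):

```python
from typing import Optional

SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]


def validate_sp_mode(mode: str, inner_ring_size: Optional[int] = None) -> None:
    # Reject modes outside the supported list, mirroring the check ShardConfig performs.
    if mode not in SUPPORT_SP_MODE:
        raise ValueError(f"Unsupported sequence parallelism mode {mode}, expected one of {SUPPORT_SP_MODE}")
    # inner_ring_size only makes sense for the new double-ring attention mode.
    if inner_ring_size is not None and mode != "ring_attn":
        raise ValueError("inner_ring_size is only meaningful for the 'ring_attn' mode")


validate_sp_mode("ring_attn", inner_ring_size=2)
```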

2
colossalai/testing/utils.py

@@ -176,7 +176,7 @@ def rerun_if_address_is_in_use():
     else:
         exception = Exception

-    func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
+    func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*(A|a)ddress already in use.*")
     return func_wrapper
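The widened pattern matches both capitalizations of the error message emitted by different backends; a quick standalone check:

```python
import re

pattern = re.compile(r".*(A|a)ddress already in use.*")
# Both spellings of the bind error are now retried.
assert pattern.match("RuntimeError: Address already in use")
assert pattern.match("bind failed: address already in use (98)")
```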

12
colossalai/zero/gemini/gemini_optimizer.py

@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-import warnings
 from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union

 import torch
@@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
         self.param_groups_backup = list()
+        self.logger = get_dist_logger()

         # Mapping from integer id to real/fake param tensor, used for checkpointing.
         self.id_to_real_params: Dict[int, Parameter] = dict()
         self.id_to_fake_params: Dict[int, Parameter] = dict()
@@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper):
         for name, param in module.named_parameters():
             if is_ddp_ignored(param):
                 if param.requires_grad:
-                    warnings.warn(
+                    self.logger.warning(
                         f"Parameter `{name}` is ignored by DDP but requires gradient! "
-                        "You should handle its optimizer update by yourself!"
+                        "You should handle its optimizer update by yourself!",
+                        ranks=[0],
                     )
                 else:
                     ddp_param_list.append(param)
@@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper):
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
+        self.logger.warning(
+            f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
+        )


 class GeminiAdamOptimizer(GeminiOptimizer):
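Switching from `warnings.warn` to the distributed logger keeps each message to rank 0 instead of being repeated by every process. A minimal sketch of the pattern, assuming the code runs inside a normal `colossalai.launch` context:

```python
from colossalai.logging import get_dist_logger

logger = get_dist_logger()
# Only rank 0 emits the message; plain warnings.warn would fire once per process.
logger.warning(
    "Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm",
    ranks=[0],
)
```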

7
docs/source/en/basics/launch_colossalai.md

@@ -131,17 +131,18 @@ with one simple command. There are two ways you can launch multi-node jobs.

 This is suitable when you only have a few nodes. Let's say I have two nodes, namely `host1` and `host2`, I can start
 multi-node training with the following command. Compared to single-node training, you must specify the `master_addr`
-option, which is auto-set to localhost if running on a single node only.
+option, which is auto-set to localhost if running on a single node only. \
+Additionally, you must also ensure that all nodes share the same open ssh port, which can be specified using --ssh-port.

 :::caution

-`master_addr` cannot be localhost when running on multiple nodes, it should be the hostname or IP address of a node.
+`master_addr` cannot be localhost when running on multiple nodes, it should be the **hostname or IP address** of a node.

 :::

 ```shell
 # run on these two nodes
-colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22
 ```

 - Run with `--hostfile`

6
docs/source/zh-Hans/basics/launch_colossalai.md

@@ -116,17 +116,17 @@ colossalai run --nproc_per_node 4 --master_port 29505 test.py
 - Launch with `--hosts`

 This is suitable when there are only a few nodes. Suppose we have two nodes, `host1` and `host2`; we can launch multi-node training with the following command.
-Compared to single-node training, multi-node training requires setting `--master_addr` manually (in single-node training `master_addr` defaults to `127.0.0.1`).
+Compared to single-node training, multi-node training requires setting `--master_addr` manually (in single-node training `master_addr` defaults to `127.0.0.1`). You also need to make sure every node uses the same ssh port, which can be set with --ssh-port.

 :::caution

-When training on multiple nodes, `master_addr` cannot be `localhost` or `127.0.0.1`; it should be the name or IP address of a node.
+When training on multiple nodes, `master_addr` cannot be `localhost` or `127.0.0.1`; it should be the **name or IP address** of a node.

 :::

 ```shell
 # train on these two nodes
-colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22
 ```

33
examples/language/llama/benchmark.py

@@ -28,6 +28,7 @@ warnings.filterwarnings("ignore")
 # Constants
 # ==============================

+# We have lots of llamas for your choice!
 MODEL_CONFIGS = {
     "100m": LlamaConfig(
         max_position_embeddings=4096,
@@ -36,6 +37,7 @@ MODEL_CONFIGS = {
         intermediate_size=2048,
         hidden_size=1024,
     ),
+    "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
     "7b": LlamaConfig(max_position_embeddings=4096),
     "13b": LlamaConfig(
         hidden_size=5120,
@@ -68,9 +70,6 @@ def main():
         default="gemini",
         help="Choose which plugin to use",
     )
-    parser.add_argument(
-        "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
-    )
     parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
     parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
     parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
@@ -94,13 +93,26 @@ def main():
     parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
-    parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
+    parser.add_argument("--profile", action="store_true", help="Profile the code")
+    parser.add_argument(
+        "--nsys",
+        action="store_true",
+        help="Use nsys for profiling. \
+            You should put something like this before colossalai launch: \
+            nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
+    )
     parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
-    parser.add_argument("--overlap_allgather", action="store_true")
     parser.add_argument("--use_fp8", action="store_true")
+    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument(
+        "--sp_mode",
+        default="all_to_all",
+        choices=["all_to_all", "ring_attn", "ring", "split_gather"],
+        help="Sequence parallelism mode",
+    )
     args = parser.parse_args()

     colossalai.launch_from_torch()
@@ -203,13 +215,12 @@ def main():
             num_model_chunks=args.n_chunks,
             zero_stage=args.zero,
             sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
             enable_sequence_parallelism=args.sp > 1,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             precision="bf16",
-            dp_outside=False,
-            overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
@@ -303,8 +314,9 @@ def main():
     with get_profile_context(
         args.profile,
         args.ignore_steps,
-        len(dataloader) - 1,
+        1,  # avoid creating massive log files
         save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+        nsys=args.nsys,
     ) as prof:
         if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
             data_iter = iter(dataloader)
@@ -330,13 +342,16 @@ def main():
                 performance_evaluator.on_step_start(step)
                 outputs = model(**batch)
                 loss = outputs[0]
+                del outputs  # free memory
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
                 booster.backward(loss, optimizer)
                 optimizer.step()
                 optimizer.zero_grad()

                 performance_evaluator.on_step_end(**batch)
                 prof.step()

     performance_evaluator.on_fit_end()
     coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")

2
examples/language/opt/README.md

@@ -17,7 +17,7 @@ limitations under the License.

 ## OPT
 Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
-The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
+The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.

 ## Our Modifications

24
examples/language/performance_evaluator.py

@@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
     return tensor.item()


-def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False):
     class DummyProfiler:
         def __init__(self):
             self.step_number = 0
@@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
         def __exit__(self, exc_type, exc_value, traceback):
             pass

+    class NsysProfiler:
+        def __init__(self, warmup_steps, active_steps):
+            self.step_number = 0
+            self.warmup_steps = warmup_steps
+            self.active_steps = active_steps
+
+        def step(self):
+            if self.step_number == self.warmup_steps:
+                torch.cuda.cudart().cudaProfilerStart()
+            elif self.step_number == self.warmup_steps + self.active_steps:
+                torch.cuda.cudart().cudaProfilerStop()
+            self.step_number += 1
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            pass
+
     if enable_flag:
+        if nsys:
+            return NsysProfiler(warmup_steps, active_steps)
+
         return profile(
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
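A usage sketch of the helper above (illustrative values; assumes `get_profile_context` from the hunk above is in scope). When `nsys=True` the wrapper only toggles `cudaProfilerStart/Stop`, so the process itself must be launched under `nsys profile --capture-range=cudaProfilerApi`, as the benchmark's `--nsys` help text describes.

```python
# Minimal loop showing how the step-gated profiler context is driven.
with get_profile_context(
    enable_flag=True,
    warmup_steps=2,
    active_steps=1,
    save_dir="profile/demo",
    nsys=False,  # set True when running under nsys with --capture-range=cudaProfilerApi
) as prof:
    for step in range(5):
        # ... forward / backward / optimizer.step() would go here ...
        prof.step()  # advances the schedule (or triggers cudaProfilerStart/Stop)
```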

2
examples/tutorial/opt/opt/README.md

@@ -19,7 +19,7 @@ limitations under the License.

 ## OPT
 Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
-The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
+The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.

 We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
 the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).

8
extensions/pybind/flash_attention/flash_attention_dao_cuda.py

@@ -57,14 +57,14 @@ class FlashAttentionDaoCudaExtension(_Extension):
             q_indices: Optional[torch.Tensor] = None,
             kv_indices: Optional[torch.Tensor] = None,
         ):
-            # [B, N, S, D] -> [B, S, N, D]
+            # [B, H, S, D] -> [B, S, H, D]
             q = q.transpose(1, 2)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
             b, s_q = q.shape[:2]
             if cu_seqlens_q is not None:
                 # padded / padded causal
-                # unpad input: [B, S, N, D] -> [T, N, D]
+                # unpad input: [B, S, H, D] -> [T, H, D]
                 q = _unpad_input(q, q_indices)
                 kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
                 attn_output = flash_attn_varlen_kvpacked_func(
@@ -78,7 +78,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
                     softmax_scale=scale,
                     causal=is_causal,
                 )
-                # pad output: [T, N, D] -> [B, S, N, D]
+                # pad output: [T, H, D] -> [B, S, H, D]
                 attn_output = pad_input(attn_output, q_indices, b, s_q)
             else:
                 # causal / no attn mask
@@ -90,7 +90,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
                     softmax_scale=scale,
                     causal=is_causal,
                 )
-            # [B, S, N, D] -> [B, N, S, D]
+            # [B, S, H, D] -> [B, H, S, D]
             return attn_output.transpose(1, 2)

         return flash_attention
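The comment fix only clarifies that the second dimension is the head count H (not N); the transposes themselves are unchanged. A standalone shape sketch of the round trip, with made-up sizes:

```python
import torch

B, H, S, D = 2, 8, 16, 64
q = torch.randn(B, H, S, D)      # layout used inside the attention module
q_flash = q.transpose(1, 2)      # [B, S, H, D], the layout flash-attn kernels expect
out = q_flash.transpose(1, 2)    # back to [B, H, S, D] after the kernel returns
assert out.shape == (B, H, S, D)
```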

2
requirements/requirements-test.txt

@@ -9,7 +9,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package
 torchrec==0.2.0
 contexttimer
 einops
-triton==2.1.0
+triton
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
 SentencePiece
 ninja

2
requirements/requirements.txt

@@ -8,7 +8,7 @@ click
 fabric
 contexttimer
 ninja
-torch>=2.1.0,<=2.3.0
+torch>=2.1.0,<=2.4.0
 safetensors
 einops
 pydantic

4
tests/kit/model_zoo/__init__.py

@@ -22,9 +22,9 @@ COMMON_MODELS = [
     "transformers_bloom_for_causal_lm",
     "transformers_falcon_for_causal_lm",
     "transformers_chatglm_for_conditional_generation",
-    "transformers_llama_for_casual_lm",
+    "transformers_llama_for_causal_lm",
     "transformers_vit_for_masked_image_modeling",
-    "transformers_mistral_for_casual_lm",
+    "transformers_mistral_for_causal_lm",
 ]

 IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"

12
tests/kit/model_zoo/transformers/command.py

@@ -32,8 +32,8 @@ if HAS_COMMAND:
         return dict(input_ids=input_ids, attention_mask=attention_mask)

-    # label is needed for casual lm
-    def data_gen_for_casual_lm():
+    # label is needed for causal lm
+    def data_gen_for_causal_lm():
         data = data_gen()
         labels = data["input_ids"].clone()
         data["labels"] = labels
@@ -44,7 +44,7 @@ if HAS_COMMAND:
     # function to get the loss
     loss_fn = lambda output: output["last_hidden_state"].mean()
-    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_causal_lm = lambda output: output["loss"]
     loss_fn_for_seq_classification = lambda output: output["logits"].mean()

     config = CohereConfig(
@@ -70,10 +70,10 @@ if HAS_COMMAND:
         model_attribute=ModelAttribute(has_control_flow=True),
     )
     model_zoo.register(
-        name="transformers_command_for_casual_lm",
+        name="transformers_command_for_causal_lm",
         model_fn=lambda: transformers.CohereForCausalLM(config),
-        data_gen_fn=data_gen_for_casual_lm,
+        data_gen_fn=data_gen_for_causal_lm,
         output_transform_fn=output_transform_fn,
-        loss_fn=loss_fn_for_casual_lm,
+        loss_fn=loss_fn_for_causal_lm,
         model_attribute=ModelAttribute(has_control_flow=True),
     )

41
tests/kit/model_zoo/transformers/llama.py

@@ -33,20 +33,21 @@ if HAS_LLAMA:
             [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
         ]
     ).long()
-    attention_mask = torch.Tensor(
-        [
-            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-        ]
-    ).long()
+    attention_mask = torch.ones_like(input_ids)
     return dict(input_ids=input_ids, attention_mask=attention_mask)

-    # label is needed for casual lm
-    def data_gen_for_casual_lm():
+    # label is needed for causal lm
+    def data_gen_for_causal_lm():
         data = data_gen()
+        # Test padded sequence
+        padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long)
+        data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
+        data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
+        ignore_idx = -100
         labels = data["input_ids"].clone()
+        labels[~data["attention_mask"].bool()] = ignore_idx
         data["labels"] = labels
         return data
@@ -55,7 +56,7 @@ if HAS_LLAMA:
     # function to get the loss
     loss_fn = lambda output: output["last_hidden_state"].mean()
-    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_causal_lm = lambda output: output["loss"]
     loss_fn_for_seq_classification = lambda output: output["logits"].mean()

     config = LlamaConfig(
@@ -70,23 +71,23 @@ if HAS_LLAMA:
         config.pad_token_id = config.eos_token_id

     # register the following models
-    # transformers.LlamaModel,
     # transformers.LlamaForCausalLM,
+    # transformers.LlamaModel,
     # transformers.LlamaForSequenceClassification,
     model_zoo.register(
-        name="transformers_llama",
-        model_fn=lambda: transformers.LlamaModel(config),
-        data_gen_fn=data_gen,
+        name="transformers_llama_for_causal_lm",
+        model_fn=lambda: transformers.LlamaForCausalLM(config),
+        data_gen_fn=data_gen_for_causal_lm,
         output_transform_fn=output_transform_fn,
-        loss_fn=loss_fn,
+        loss_fn=loss_fn_for_causal_lm,
         model_attribute=ModelAttribute(has_control_flow=True),
     )
     model_zoo.register(
-        name="transformers_llama_for_casual_lm",
-        model_fn=lambda: transformers.LlamaForCausalLM(config),
-        data_gen_fn=data_gen_for_casual_lm,
+        name="transformers_llama",
+        model_fn=lambda: transformers.LlamaModel(config),
+        data_gen_fn=data_gen,
         output_transform_fn=output_transform_fn,
-        loss_fn=loss_fn_for_casual_lm,
+        loss_fn=loss_fn,
         model_attribute=ModelAttribute(has_control_flow=True),
     )
     model_zoo.register(
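The new `data_gen_for_causal_lm` above pads the batch and masks padded positions out of the loss with the usual `-100` ignore index. The same recipe as a standalone sketch (values are arbitrary):

```python
import torch

input_ids = torch.randint(1, 1000, (2, 16))
attention_mask = torch.ones_like(input_ids)

# Append a padded half, mirroring the test data generator above.
padding = torch.zeros(2, input_ids.shape[1] // 2, dtype=torch.long)
input_ids = torch.cat([input_ids, padding], dim=1)
attention_mask = torch.cat([attention_mask, padding], dim=1)

labels = input_ids.clone()
labels[~attention_mask.bool()] = -100  # ignored by cross-entropy during training
```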

2
tests/kit/model_zoo/transformers/mistral.py

@@ -64,7 +64,7 @@ model_zoo.register(
     model_attribute=ModelAttribute(has_control_flow=True),
 )
 model_zoo.register(
-    name="transformers_mistral_for_casual_lm",
+    name="transformers_mistral_for_causal_lm",
     model_fn=lambda: transformers.MistralForCausalLM(config),
     data_gen_fn=data_gen_for_lm,
     output_transform_fn=output_transform_fn,

2
tests/kit/model_zoo/transformers/mixtral.py

@@ -53,6 +53,8 @@ config = MixtralConfig(
     num_attention_heads=8,
     num_hidden_layers=2,
     vocab_size=1000,
+    attn_implementation="flash_attention_2",
+    torch_dtype="float16",
     output_router_logits=True,
 )

12
tests/kit/model_zoo/transformers/qwen2.py

@@ -33,8 +33,8 @@ if HAS_QWEN2:
         attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
         return dict(input_ids=input_ids, attention_mask=attention_mask)

-    # label is needed for casual lm
-    def data_gen_for_casual_lm():
+    # label is needed for causal lm
+    def data_gen_for_causal_lm():
         data = data_gen()
         labels = data["input_ids"].clone()
         data["labels"] = labels
@@ -45,7 +45,7 @@ if HAS_QWEN2:
     # function to get the loss
     loss_fn = lambda output: output["last_hidden_state"].mean()
-    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_causal_lm = lambda output: output["loss"]
     loss_fn_for_seq_classification = lambda output: output["logits"].mean()

     config = Qwen2Config(
@@ -72,11 +72,11 @@ if HAS_QWEN2:
         model_attribute=ModelAttribute(has_control_flow=True),
     )
     model_zoo.register(
-        name="transformers_qwen2_for_casual_lm",
+        name="transformers_qwen2_for_causal_lm",
         model_fn=lambda: transformers.Qwen2ForCausalLM(config),
-        data_gen_fn=data_gen_for_casual_lm,
+        data_gen_fn=data_gen_for_causal_lm,
         output_transform_fn=output_transform_fn,
-        loss_fn=loss_fn_for_casual_lm,
+        loss_fn=loss_fn_for_causal_lm,
         model_attribute=ModelAttribute(has_control_flow=True),
     )
     model_zoo.register(

2
tests/test_booster/test_plugin/test_3d_plugin.py

@@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
     # TODO(ver217): add more models
     for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(
-        "transformers_llama_for_casual_lm"
+        "transformers_llama_for_causal_lm"
     ).items():
         err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)

2
tests/test_booster/test_plugin/test_low_level_zero_plugin.py

@@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
     sub_model_zoo = model_zoo.get_sub_registry(model_name)
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         task_type = None
-        if name == "transformers_llama_for_casual_lm":
+        if name == "transformers_llama_for_causal_lm":
             task_type = "CAUSAL_LM"
         if name == "transformers_llama_for_sequence_classification":
             task_type = "SEQ_CLS"

2
tests/test_booster/test_plugin/test_torch_ddp_plugin.py

@@ -47,7 +47,7 @@ def check_torch_ddp_plugin():
     registry = model_zoo

     for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
-        if name == "dlrm_interactionarch" or name.startswith("simple_"):
+        if name in ("dlrm_interactionarch", "transformers_mixtral") or name.startswith("simple_"):
             continue
         run_fn(model_fn, data_gen_fn, output_transform_fn)
         torch.cuda.empty_cache()

2
tests/test_checkpoint_io/test_gemini_checkpoint_io.py

@@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])

2
tests/test_checkpoint_io/test_gemini_torch_compability.py

@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo

 @clear_cache_before_run()
 @parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
 def exam_torch_load_from_gemini(shard: bool, model_name: str):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()

2
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py

@@ -39,7 +39,7 @@ else:

 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
 @clear_cache_before_run()

2
tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py

@@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO(
         if name != "transformers_llama":
             continue
         task_type = None
-        if name == "transformers_llama_for_casual_lm":
+        if name == "transformers_llama_for_causal_lm":
             task_type = "CAUSAL_LM"
         if name == "transformers_llama_for_sequence_classification":
             task_type = "SEQ_CLS"

2
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py

@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run() @clear_cache_before_run()
@parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("plugin_type", ["ddp", "zero", "gemini"]) @parameterize("plugin_type", ["ddp", "zero", "gemini"])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(

14
tests/test_lazy/test_models.py

@@ -18,9 +18,17 @@ def test_models_lazy_init(subset, default_device):
     sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
     for name, entry in sub_model_zoo.items():
         # TODO(ver217): lazy init does not support weight norm, skip these models
-        if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
-            ("transformers_vit", "transformers_blip2", "transformers_whisper")
-        ):
+        if name in (
+            "torchaudio_wav2vec2_base",
+            "torchaudio_hubert_base",
+            "timm_beit",
+            "timm_vision_transformer",
+            "timm_deit",
+            "timm_beitv2",
+            "timm_deit3",
+            "timm_convit",
+            "timm_tnt_b_patch16_224",
+        ) or name.startswith(("transformers_vit", "transformers_blip2", "transformers_whisper")):
             continue
         check_lazy_init(entry, verbose=True, default_device=default_device)

2
tests/test_lora/test_lora.py

@@ -91,7 +91,7 @@ def run_lora_test():
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         task_type = None
-        if name == "transformers_llama_for_casual_lm":
+        if name == "transformers_llama_for_causal_lm":
             task_type = "CAUSAL_LM"
         if name == "transformers_llama_for_sequence_classification":
             task_type = "SEQ_CLS"

17
tests/test_pipeline/test_schedule/test_interleaved.py

@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -107,13 +108,13 @@ def run_pp(

     # check loss
     if stage_manager.is_last_stage(ignore_chunk=True):
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])

     # check gradients
     for i in range(num_model_chunk):
         idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
@@ -123,8 +124,8 @@ def run_pp(
     # check updated param
     for i in range(num_model_chunk):
         idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

     # forward only
     with torch.no_grad():
@@ -135,14 +136,14 @@ def run_pp(
             sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage(ignore_chunk=True):
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])

         for layer in sharded_model:
             if layer.weight.grad is None:
                 assert layer.weight.grad is None and layer.bias.grad is None
             else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))


 @pytest.mark.dist

17
tests/test_pipeline/test_schedule/test_oneF_oneB.py

@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):

     # check loss
     if stage_manager.is_last_stage():
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])

     # check gradients
     for i in range(len(sharded_model)):
         idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
@@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
     # check updated param
     for i in range(len(sharded_model)):
         idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

     # forward only
     with torch.no_grad():
@@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
             sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
         )
         if stage_manager.is_last_stage():
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])

         for layer in sharded_model:
             if layer.weight.grad is None:
                 assert layer.weight.grad is None and layer.bias.grad is None
             else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
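`torch.testing.assert_close` replaces the bare `assert torch.allclose(...)` in these tests so that a failure reports the mismatching values and uses dtype-aware default tolerances. For example:

```python
import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-7
assert_close(a, b)  # passes: within the default float32 tolerances
# assert_close(a, a + 1.0) would raise an AssertionError that prints the
# mismatched elements, unlike a plain `assert torch.allclose(...)`.
```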
def run_dist(

3
tests/test_shardformer/test_flash_attention.py

@@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma
         padding_mask = padding_mask[:, None, :, None].logical_not()
         ref_output = ref_output.masked_fill(padding_mask, 0)
+        output = output.masked_fill(padding_mask, 0)
     assert_close(output, ref_output, **tols)
     output.mean().backward()
     ref_output.mean().backward()
@@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype):
         attn_kwargs, padding_mask = gen_kwargs_func(dtype)
         for attn_func, name, need_postprocess in attn_funcs:
             print(f"{dtype}, {name}, {mask_type}")
+            if mask_type == "padded":
+                pass
             if need_postprocess:
                 check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
             else:

186
tests/test_shardformer/test_layer/test_ring_attn.py

@ -0,0 +1,186 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
@parameterize("seq_len", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype):
torch.cuda.manual_seed(2)
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
    # Some outliers may seem large, but our errors are still lower than
    # Megatron-LM context parallel's
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
# and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main)
atol = rtol = 7e-3
# Setup inputs
qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
local_qkv = split_batch_zigzag(qkv, sp_group)
q, k, v = local_qkv.unbind(dim=-3)
q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D)
q.requires_grad = k.requires_grad = v.requires_grad = True
# Ring attention vs single GPU
ring_out, ring_lse = RingAttention.attention(
q,
k,
v,
sp_group,
AttnMaskType.CAUSAL,
return_softmax=True,
inner_ring_size=max(2, sp_size // 2),
# inner_ring_size=4
)
ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func(
qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
)
    # Check output and softmax denominator
local_out = split_batch_zigzag(out, sp_group)
local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1)
local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads)
assert_close(ring_lse, local_lse, atol=atol, rtol=rtol)
assert_close(ring_out, local_out, atol=atol, rtol=rtol)
# Check grads
ring_out.sum().backward()
out.sum().backward()
ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
dqkv = qkv.grad
local_dqkv = split_batch_zigzag(dqkv, sp_group)
assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
if dist.get_rank() == 0:
print(
f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed."
)
@parameterize("seqlen", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
atol = rtol = 7e-3
torch.cuda.manual_seed(2)
# Prepare varlen attention mask
padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device)
padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0
padding_mask[:, seqlen // 2 :] = 0
input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
# Forward
# out = ColoAttention.attention(q, k, v, **mask_info)
flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()]
qkv = torch.stack([flat_input] * 3, dim=1)
qkv.retain_grad()
input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds)
out, lse, _ = flash_attn_varlen_qkvpacked_func(
qkv,
mask_info["cu_seqlens"] * sp_size,
mask_info["max_seqlen"] * sp_size,
return_attn_probs=True,
causal=True,
# deterministic=True
)
# Test the splitting function
local_input = split_varlen_zigzag(
flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all()
del local_input, flat_input
q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)]
q_ring.retain_grad()
k_ring.retain_grad()
v_ring.retain_grad()
ring_out, ring_lse = RingAttention.attention(
q_ring,
k_ring,
v_ring,
sp_group,
**mask_info,
pad_output=False,
return_softmax=True,
# deterministic=True
)
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
# Check output
lse = lse.transpose(0, 1)
out, lse = split_varlen_zigzag(
[out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
assert_close(lse, ring_lse, atol=atol, rtol=rtol)
assert_close(out, ring_out, atol=atol, rtol=rtol)
# Check grads
labels = torch.ones(out.shape[0], dtype=dtype, device=device)
F.mse_loss(out.sum((-2, -1)), labels).backward()
F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward()
dq, dk, dv = [
split_varlen_zigzag(
qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
for i in range(3)
]
dq_ring, dk_ring, dv_ring = [
x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]]
for x in (q_ring.grad, k_ring.grad, v_ring.grad)
]
assert_close(dq, dq_ring, atol=atol, rtol=rtol)
assert_close(dk, dk_ring, atol=atol, rtol=rtol)
assert_close(dv, dv_ring, atol=atol, rtol=rtol)
def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq()
check_ring_attn()
def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_ring_attn()
@rerun_if_address_is_in_use()
@parameterize("world_size", [2])
def test_ring_attn(world_size):
spawn(launch_single_ring, nprocs=world_size)
@rerun_if_address_is_in_use()
@parameterize("world_size", [4])
def test_double_ring(world_size):
spawn(launch_double_ring, nprocs=world_size)
if __name__ == "__main__":
test_ring_attn()
test_double_ring()
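The test relies on `split_batch_zigzag`, which deals each rank one chunk from the front and one from the back of the sequence so causal-attention work stays balanced around the ring. A toy illustration of that assignment (not the library implementation; shapes and naming here are simplified):

```python
import torch


def zigzag_chunks(seq_len: int, sp_size: int, rank: int) -> torch.Tensor:
    # Split [0, seq_len) into 2 * sp_size chunks; rank i keeps chunk i and chunk (2*sp_size - 1 - i),
    # pairing a cheap (early) chunk with an expensive (late) one under the causal mask.
    chunks = torch.arange(seq_len).chunk(2 * sp_size)
    return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]])


# With seq_len=16 and sp_size=2: rank 0 holds tokens [0..3] and [12..15],
# rank 1 holds tokens [4..7] and [8..11].
print(zigzag_chunks(16, 2, 0))
print(zigzag_chunks(16, 2, 1))
```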

27
tests/test_shardformer/test_model/_utils.py

@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Module from torch.nn import Module
from torch.optim import Adam, Optimizer from torch.optim import Adam, Optimizer
from torch.testing import assert_close from torch.testing import assert_close
from transformers.modeling_outputs import BaseModelOutputWithPast
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster from colossalai.booster import Booster
@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin(
org_output = org_model(**unshard_test_data) org_output = org_model(**unshard_test_data)
org_loss = criterion(org_output) org_loss = criterion(org_output)
org_loss.backward() org_loss.backward()
return org_loss, org_output, sharded_loss, sharded_output return org_loss, org_output, sharded_loss, sharded_output
@@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin(

 def check_output_hidden_state(
-    org_output: Tensor,
-    sharded_output: Tensor,
+    org_output: BaseModelOutputWithPast,
+    sharded_output: BaseModelOutputWithPast,
     stage_manager: Optional[PipelineStageManager] = None,
     atol: float = 1e-5,
     rtol: float = 1e-3,
+    shard_config: Optional[ShardConfig] = None,
 ):
     org_hidden_state = org_output.last_hidden_state
@@ -315,6 +316,14 @@ def check_output_hidden_state(
     else:
         sharded_hidden_state = sharded_output.last_hidden_state

+    # Check if the output sequence is gathered before cross entropy
+    if shard_config is not None:
+        seq_dim = 1
+        sp_group = shard_config.sequence_parallel_process_group
+        sp_size = shard_config.sequence_parallel_size
+        if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
+            org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
+
     assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
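The branch added above compares only the local sequence shard whenever the sharded model returns hidden states that are still split along the sequence dimension (i.e. 1/sp_size of the full length). A minimal sketch of that comparison with illustrative names, not the test helper's actual signature:

import torch
from torch.testing import assert_close


def compare_local_shard(full: torch.Tensor, local: torch.Tensor, sp_size: int, rank: int):
    seq_dim = 1  # layout assumed to be (batch, seq, hidden)
    if full.shape[seq_dim] == local.shape[seq_dim] * sp_size:
        # Keep only this rank's chunk of the reference output before comparing.
        full = full.chunk(sp_size, dim=seq_dim)[rank]
    assert_close(full.float(), local.float(), atol=1e-5, rtol=1e-3)


full = torch.randn(2, 8, 4)
compare_local_shard(full, full.chunk(2, dim=1)[1], sp_size=2, rank=1)  # passes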
@@ -374,8 +383,11 @@ def get_grad_tensors_for_check(
             shard_grad = torch.cat(shard_grad_list, dim=dim)

         # embedding may be resized when using tensor parallel
-        if shard_grad.shape[0] > org_grad.shape[0]:
-            shard_grad = shard_grad[: org_grad.shape[0], :]
+        try:
+            if shard_grad.shape[0] > org_grad.shape[0]:
+                shard_grad = shard_grad[: org_grad.shape[0], :]
+        except:
+            pass
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
@@ -404,9 +416,6 @@ def check_grad(
         org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
         shard_weight = getattr_(sharded_model, suffix).weight
-        # if verbose and dist.get_rank() == 0:
-        #     print("shard_weight", shard_weight)
-        #     print("org_grad", org_grad)
         if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
             shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
             dist.all_gather(shard_grad_list, shard_grad, tp_group)
@@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors):
        "org_grad": tensor to be compared from the original model
        "shard_grad": tensor to be compared from the sharded model
    """
-    for suffix, check_info in check_tensors.items():
+    for idx, (suffix, check_info) in enumerate(check_tensors.items()):
        org_grad = check_info["org_grad"]
        shard_grad = check_info["shard_grad"]
        rtol = check_info["rtol"]

tests/test_shardformer/test_model/test_shard_command.py (4 changed lines)

@@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     ],
 )
 def run_command_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -321,7 +321,7 @@ def run_command_test(test_config):
     ],
 )
 def run_command_3d_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

tests/test_shardformer/test_model/test_shard_llama.py (113 changed lines)

@@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
         master2working = sharded_optimizer.get_master_to_working_map()
-        for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+        for (name, p1), p2 in zip(
+            llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]
+        ):
             working_p = master2working[id(p2)]
             grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
             grad_index = (
@@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             )
             grad = grads[grad_index]
             sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
-            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+            try:
+                assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+            except Exception as e:
+                raise RuntimeError(f"Failed to check grad for {name}") from e

     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
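The assertion wrapped above compares each rank's slice of the flattened parameter gradient against the gradient partition held by the ZeRO optimizer, so a failure can now be attributed to a named parameter. A self-contained sketch of the slicing being checked, with illustrative names rather than the optimizer's API:

import torch


def local_grad_slice(grad: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # ZeRO-style partitioning: each rank keeps one contiguous chunk of the flat gradient.
    return grad.view(-1).chunk(world_size)[rank]


g = torch.arange(12.0)
print(local_grad_slice(g, world_size=4, rank=2))  # tensor([6., 7., 8.])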
@@ -114,75 +119,103 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 5e-3
         if org_model.__class__.__name__ == "LlamaModel":
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+            check_output_hidden_state(
+                org_output,
+                sharded_output,
+                stage_manager,
+                atol=atol,
+                rtol=rtol,
+                shard_config=booster.plugin.shard_config,
+            )
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

     # check weights
     if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-4, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        try:
-            check_weight(
-                llama_model,
-                shard_llama_model,
-                col_layer_for_check,
-                tp_group,
-                atol=atol,
-                rtol=rtol,
-                dim=1,
-                verbose=False,
-            )
-        except Exception as e:
-            print(f"Failed config: {test_config}")
-            raise e
+        check_weight(
+            llama_model,
+            shard_llama_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )

     # check grads
     check_all_grad_tensors(grads_to_check)

     torch.cuda.empty_cache()

 @parameterize(
     "test_config",
     [
-        {  # Ulysess + Flash attention
+        # Double Ring Attention
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 4,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring_attn",
+            "use_lazy_init": True,
+            "zero_stage": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "inner_ring_size": 2,
+        },
+        # Ring Attention + PP
+        {
             "tp_size": 1,
             "pp_size": 2,
             "sp_size": 2,
             "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
+            "sequence_parallelism_mode": "ring_attn",
             "use_lazy_init": True,
-            "zero_stage": 0,
+            "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {  # Test ring + Flash attention
+        # Ring Attention + TP
+        {
             "tp_size": 2,
             "pp_size": 1,
             "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "ring",
-            "enable_flash_attention": True,
+            "sequence_parallelism_mode": "ring_attn",
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
+        {  # Ulysess + TP
+            "tp_size": 2,
             "pp_size": 1,
             "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
+            "enable_all_optimization": True,
             "use_lazy_init": True,
-            "zero_stage": 1,
+            "zero_stage": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {  # Ulysess + PP
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 0,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -192,8 +225,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "sp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
             "use_lazy_init": True,
+            "zero_stage": 2,
             "precision": "fp16",
             "initial_scale": 1,
         },
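Each test_config dict above is consumed by the shardformer test harness, which builds a HybridParallelPlugin from it. A hedged sketch of that mapping, assuming the keys correspond one-to-one to HybridParallelPlugin keyword arguments in this ColossalAI version, that use_lazy_init is handled by the harness rather than the plugin, and that the code runs inside a process group set up by colossalai.launch:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

test_config = {
    "tp_size": 2,
    "pp_size": 1,
    "sp_size": 2,
    "num_microbatches": 1,
    "enable_sequence_parallelism": True,
    "sequence_parallelism_mode": "ring_attn",
    "use_lazy_init": True,
    "zero_stage": 2,
    "precision": "fp16",
    "initial_scale": 1,
}
test_config.pop("use_lazy_init")  # assumed to be consumed by the test harness, not the plugin
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)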
@@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
+            continue
         try:
             check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
         except Exception as e:
-            print(f"Failed config: {test_config}")
+            print(f"Failed config: {test_config}, model name: {name}")
             raise e

     clear_layout_converter()
     Randomizer.reset_index()
