mirror of https://github.com/hpcaitech/ColossalAI
[zerobubble] rebase main (#6075)
* fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * fix scaling algorithm in FP8 casting * support fp8 communication in pipeline parallelism * add fp8_communication flag in the script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * shardformer fp8 * fix rebase * remove all to all * fix shardformer fp8 communication training degradation * [fp8] support all-gather flat tensor (#5932) * [fp8] add fp8 comm for low level zero * [test] add zero fp8 test case * [Feature] llama shardformer fp8 support (#5938) * add llama shardformer fp8 * Llama Shardformer Parity * fix typo * fix all reduce * fix pytest failure * fix reduce op and move function to fp8.py * fix typo * [FP8] rebase main (#5963) * add SimPO * fix dataloader * remove debug code * add orpo * fix style * fix colossalai, transformers version * fix colossalai, transformers version * fix colossalai, transformers version * fix torch colossalai version * update transformers version * [shardformer] DeepseekMoE support (#5871) * [Feature] deepseek moe expert parallel implement * [misc] fix typo, remove redundant file (#5867) * [misc] fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] deepseek support & unit test * [misc] remove debug code & useless print * [misc] fix typos (#5872) * [Feature] remove modeling file, use auto config. (#5884) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [Deepseek] remove redundant code (#5888) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [Feature/deepseek] resolve comment. (#5889) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [misc] mv module replacement into if branch * [misc] add some warning message and modify some code in unit test * [misc] fix typos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838) * Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support * [HotFix] CI,import,requirements-test for #5838 (#5892) * [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Enable PP + SP for llama (#5868) * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897) * add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. 
Support lora with gradient checkpoint * fix style * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix eval * hotfix citation * [zero] support all-gather overlap (#5898) * [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api * fix orpo cross entropy loss * [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446) * Remove unnecessary calls to deepcopy * Build DimSpec's difference dict only once This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough. * Fix documentation of DimSpec's difference method * [ShardFormer] fix qwen2 sp (#5903) * [compatibility] support torch 2.2 (#5875) * Support Pytorch 2.2.2 * keep build_on_pr file and update .compatibility * fix object_to_tensor usage when torch>=2.3.0 (#5820) * [misc] support torch2.3 (#5893) * [misc] support torch2.3 * [devops] update compatibility ci * [devops] update compatibility ci * [devops] add debug * [devops] add debug * [devops] add debug * [devops] add debug * [devops] remove debug * [devops] remove debug * [release] update version (#5912) * [plugin] support all-gather overlap for hybrid parallel (#5919) * [plugin] fixed all-gather overlap support for hybrid parallel * add kto * fix style, add kto data sample * [Examples] Add lazy init to OPT and GPT examples (#5924) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [ColossalChat] Hotfix for ColossalChat (#5910) * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * fix ddp issue * add Qwen 1.5 32B * refactor tokenization * [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931) * cannot access local variable 'default_conversation' where it is not associated with a value set default value for 'default_conversation' * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix test data * refactor evaluation * remove real data path * remove real data path * Add n_fused as an input from native_module (#5894) * [FIX BUG] convert env param to int in (#5934) * [Hotfix] Fix ZeRO typo #5936 Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941) * Add a switch to control whether the model checkpoint needs to be saved after each epoch ends * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix style * fix style * fix style * [shardformer] hotfix attn mask (#5945) * [shardformer] hotfix attn mask (#5947) * [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895) * Distrifusion Support source * comp comm overlap optimization * sd3 benchmark * pixart distrifusion bug fix * sd3 bug fix and benchmark * generation bug fix * naming fix * add docstring, fix counter and shape error * add reference * readme and requirement * [zero] hotfix update master params (#5951) * [release] update version (#5952) * [Chat] Fix lora (#5946) * fix merging * remove filepath * fix style * Update README.md 
(#5958) * [hotfix] Remove unused plan section (#5957) * remove readme * fix readme * update * [test] add mixtral for sequence classification * [test] add mixtral transformer test * [moe] fix plugin * [test] mixtra pp shard test * [chore] handle non member group * [zero] solve hang * [test] pass mixtral shardformer test * [moe] implement transit between non moe tp and ep * [zero] solve hang * [misc] solve booster hang by rename the variable * solve hang when parallel mode = pp + dp * [moe] implement submesh initialization * [moe] add mixtral dp grad scaling when not all experts are activated * [chore] manually revert unintended commit * [chore] trivial fix * [chore] arg pass & remove drop token * [test] add mixtral modelling test * [moe] implement tp * [moe] test deepseek * [moe] clean legacy code * [Feature] MoE Ulysses Support (#5918) * moe sp support * moe sp bug solve * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [chore] minor fix * [moe] init moe plugin comm setting with sp * moe sp + ep bug fix * [moe] finalize test (no pp) * [moe] full test for deepseek and mixtral (pp + sp to fix) * [chore] minor fix after rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] solve moe ckpt test failure and some other arg pass failure * [moe] remove ops * [test] fix test: test_zero1_2 * [bug] fix: somehow logger hangs the program * [moe] deepseek moe sp support * [test] add check * [deepseek] replace attn (a workaround for bug in transformers) * [misc] skip redunant test * [misc] remove debug/print code * [moe] refactor mesh assignment * Revert "[moe] implement submesh initialization" This reverts commitpull/6077/head2f9bce6686
. * [chore] change moe_pg_mesh to private * [misc] remove incompatible test config * [misc] fix ci failure: change default value to false in moe plugin * [misc] remove useless condition * [chore] docstring * [moe] remove force_overlap_comm flag and add warning instead * [doc] add MoeHybridParallelPlugin docstring * [moe] solve dp axis issue * [chore] remove redundant test case, print string & reduce test tokens * [feat] Dist Loader for Eval (#5950) * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tp error * remove unused parameters * remove unused * update inference * update docs * update inference --------- Co-authored-by: Michelle <qianranma8@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [lora] lora support hybrid parallel plugin (#5956) * lora support hybrid plugin * fix * fix * fix * fix * fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 * fix scaling algorithm in FP8 casting * support fp8 communication in pipeline parallelism * add fp8_communication flag in the script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * shardformer fp8 * fix rebase * remove all to all * fix shardformer fp8 communication training degradation * [fp8] support all-gather flat tensor (#5932) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update low_level_optim.py --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Haze188 <haze188@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: zhurunhua <1281592874@qq.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: Gao, Ruiyuan <905370712@qq.com> Co-authored-by: hxwang <wang1570@e.ntu.edu.sg> Co-authored-by: Michelle <qianranma8@gmail.com> Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Co-authored-by: HangXu <hangxu0304@gmail.com> * [fp8]support all2all fp8 (#5953) * support all2all fp8 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [fp8] add fp8 linear (#5967) * [fp8] add fp8 linear * [test] fix fp8 linear test condition * [test] fix fp8 linear test condition * [test] fix fp8 linear test condition * [fp8] support fp8 amp for hybrid parallel plugin (#5975) * [fp8] support fp8 amp for hybrid parallel plugin * 
[test] add fp8 hook test * [fp8] fix fp8 linear compatibility * fix (#5976) * [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928) * support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement communication hook for FSDP params all-gather * added unit test for fp8 operators * support fp8 communication in GeminiPlugin * update training scripts to support fsdp and fp8 communication * fixed some minor bugs observed in unit test * add all_gather_into_tensor_flat_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * add skip the test if torch < 2.2.0 * add fp8_comm flag * rebase latest fp8 operators * rebase latest fp8 operators * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [test ci]Feature/fp8 comm (#5981) * fix * fix * fix * [fp8] support gemini plugin (#5978) * [fp8] refactor hook * [fp8] support gemini plugin * [example] add fp8 option for llama benchmark * [fp8] use torch compile (torch >= 2.3.0) (#5979) * [fp8] use torch compile (torch >= 2.4.0) * [fp8] set use_fast_accum in linear * [chore] formal version check * [chore] fix sig * [fp8]Moe support fp8 communication (#5977) * fix * support moe fp8 * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix fix fi * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [fp8] support hybrid parallel plugin (#5982) * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * fp8 * fix * bert and bloom * chatglm and command * gpt2,gptj,bert, falcon,blip2 * mistral,opy,sam,t5,vit,whisper * fix * fix * fix * [fp8] refactor fp8 linear with compile (#5993) * [fp8] refactor fp8 linear with compile * [fp8] fix linear test * [fp8] fix linear test * [fp8] support asynchronous FP8 communication (#5997) * fix * fix * fix * support async all2all * support async op for all gather * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004) * [fp8] linear perf enhancement * [fp8]update reduce-scatter test (#6002) * fix * fix * fix * fix * [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009) * [fp8] zero support fp8 linear. 
(#6006) * fix * fix * fix * zero fp8 * zero fp8 * Update requirements.txt * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix the merge * fix the merge * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix * fix * fix the merge * fix * fix * fix * fix * fix * fix the merge * fix * fix * fix * fix * [fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016) * add SimPO * fix dataloader * remove debug code * add orpo * fix style * fix colossalai, transformers version * fix colossalai, transformers version * fix colossalai, transformers version * fix torch colossalai version * update transformers version * [shardformer] DeepseekMoE support (#5871) * [Feature] deepseek moe expert parallel implement * [misc] fix typo, remove redundant file (#5867) * [misc] fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] deepseek support & unit test * [misc] remove debug code & useless print * [misc] fix typos (#5872) * [Feature] remove modeling file, use auto config. (#5884) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [Deepseek] remove redundant code (#5888) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [Feature/deepseek] resolve comment. (#5889) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [misc] mv module replacement into if branch * [misc] add some warning message and modify some code in unit test * [misc] fix typos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838) * Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support * [HotFix] CI,import,requirements-test for #5838 (#5892) * [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Enable PP + SP for llama (#5868) * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897) * add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. 
Support lora with gradient checkpoint * fix style * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix eval * hotfix citation * [zero] support all-gather overlap (#5898) * [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api * fix orpo cross entropy loss * [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446) * Remove unnecessary calls to deepcopy * Build DimSpec's difference dict only once This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough. * Fix documentation of DimSpec's difference method * [ShardFormer] fix qwen2 sp (#5903) * [compatibility] support torch 2.2 (#5875) * Support Pytorch 2.2.2 * keep build_on_pr file and update .compatibility * fix object_to_tensor usage when torch>=2.3.0 (#5820) * [misc] support torch2.3 (#5893) * [misc] support torch2.3 * [devops] update compatibility ci * [devops] update compatibility ci * [devops] add debug * [devops] add debug * [devops] add debug * [devops] add debug * [devops] remove debug * [devops] remove debug * [release] update version (#5912) * [plugin] support all-gather overlap for hybrid parallel (#5919) * [plugin] fixed all-gather overlap support for hybrid parallel * add kto * fix style, add kto data sample * [Examples] Add lazy init to OPT and GPT examples (#5924) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [ColossalChat] Hotfix for ColossalChat (#5910) * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * fix ddp issue * add Qwen 1.5 32B * refactor tokenization * [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931) * cannot access local variable 'default_conversation' where it is not associated with a value set default value for 'default_conversation' * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix test data * refactor evaluation * remove real data path * remove real data path * Add n_fused as an input from native_module (#5894) * [FIX BUG] convert env param to int in (#5934) * [Hotfix] Fix ZeRO typo #5936 Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941) * Add a switch to control whether the model checkpoint needs to be saved after each epoch ends * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix style * fix style * fix style * [shardformer] hotfix attn mask (#5945) * [shardformer] hotfix attn mask (#5947) * [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895) * Distrifusion Support source * comp comm overlap optimization * sd3 benchmark * pixart distrifusion bug fix * sd3 bug fix and benchmark * generation bug fix * naming fix * add docstring, fix counter and shape error * add reference * readme and requirement * [zero] hotfix update master params (#5951) * [release] update version (#5952) * [Chat] Fix lora (#5946) * fix merging * remove filepath * fix style * Update README.md 
(#5958) * [hotfix] Remove unused plan section (#5957) * remove readme * fix readme * update * [test] add mixtral for sequence classification * [test] add mixtral transformer test * [moe] fix plugin * [test] mixtra pp shard test * [chore] handle non member group * [zero] solve hang * [test] pass mixtral shardformer test * [moe] implement transit between non moe tp and ep * [zero] solve hang * [misc] solve booster hang by rename the variable * solve hang when parallel mode = pp + dp * [moe] implement submesh initialization * [moe] add mixtral dp grad scaling when not all experts are activated * [chore] manually revert unintended commit * [chore] trivial fix * [chore] arg pass & remove drop token * [test] add mixtral modelling test * [moe] implement tp * [moe] test deepseek * [moe] clean legacy code * [Feature] MoE Ulysses Support (#5918) * moe sp support * moe sp bug solve * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [chore] minor fix * [moe] init moe plugin comm setting with sp * moe sp + ep bug fix * [moe] finalize test (no pp) * [moe] full test for deepseek and mixtral (pp + sp to fix) * [chore] minor fix after rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] solve moe ckpt test failure and some other arg pass failure * [moe] remove ops * [test] fix test: test_zero1_2 * [bug] fix: somehow logger hangs the program * [moe] deepseek moe sp support * [test] add check * [deepseek] replace attn (a workaround for bug in transformers) * [misc] skip redunant test * [misc] remove debug/print code * [moe] refactor mesh assignment * Revert "[moe] implement submesh initialization" This reverts commit2f9bce6686
. * [chore] change moe_pg_mesh to private * [misc] remove incompatible test config * [misc] fix ci failure: change default value to false in moe plugin * [misc] remove useless condition * [chore] docstring * [moe] remove force_overlap_comm flag and add warning instead * [doc] add MoeHybridParallelPlugin docstring * [moe] solve dp axis issue * [chore] remove redundant test case, print string & reduce test tokens * [feat] Dist Loader for Eval (#5950) * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tp error * remove unused parameters * remove unused * update inference * update docs * update inference --------- Co-authored-by: Michelle <qianranma8@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [lora] lora support hybrid parallel plugin (#5956) * lora support hybrid plugin * fix * fix * fix * fix * Support overall loss, update KTO logging * [Docs] clarify launch port Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Hotfix] README link (#5966) * update ignore * update readme * run style * update readme * [Hotfix] Avoid fused RMSnorm import error without apex (#5985) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Chat] fix readme (#5989) * fix readme * fix readme, tokenization fully tested * fix readme, tokenization fully tested * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix sync condition (#6000) * [plugin] add cast inputs option for zero (#6003) * [pre-commit.ci] pre-commit autoupdate (#5995) updates: - [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991) * [Feature] Zigzag Ring attention (#5905) * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update compatibility (#6008) * [misc] update compatibility * [misc] update requirements * [devops] disable requirements cache * [test] fix torch ddp test * [test] fix rerun on address in 
use * [test] fix lazy init * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix the merge * overlap kv comm with output rescale (#6017) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * fix the merge * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix * fix * fix the merge * fix * [misc] Use dist logger in plugins (#6011) * use dist logger in plugins * remove trash * print on rank 0 --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * fix * fix * fix * fix * fix the merge * fix * fix * fix * fix --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Haze188 <haze188@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: zhurunhua <1281592874@qq.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: Gao, Ruiyuan <905370712@qq.com> Co-authored-by: hxwang <wang1570@e.ntu.edu.sg> Co-authored-by: Michelle <qianranma8@gmail.com> Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local> * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update train_dpo.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update low_level_zero_plugin.py * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [CI] Remove triton version for compatibility bug; update req torch >=2.2 (#6018) * remove triton version * remove torch 2.2 * remove torch 2.1 * debug * remove 2.1 build tests * require torch >=2.2 --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [plugin] hotfix zero plugin (#6036) * [plugin] hotfix zero plugin * [plugin] hotfix zero plugin * [Colossal-LLaMA] Refactor latest APIs (#6030) * refactor latest code * update api * add dummy dataset * update Readme * add setup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update files * add PP support * update arguments * update argument * reorg folder * update version * remove IB infor * update utils * update readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update save for zero * update save * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add apex * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add fused norm (#6038) * [FP8] unsqueeze scale to make it compatible with torch.compile (#6040) * [colossalai/checkpoint_io/...] 
fix bug in load_state_dict_into_model; format error msg (#6020) * fix bug in load_state_dict_into_model; format error msg * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py to support checking missing_keys * Update general_checkpoint_io.py fix bug in missing_keys error message * retrigger tests --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hotfix] Remove deprecated install (#6042) * remove deprecated install * remove unused folder * [fp8] optimize all-gather (#6043) * [fp8] optimize all-gather * [fp8] fix all gather fp8 ring * [fp8] enable compile * [fp8] fix all gather fp8 ring * [fp8] fix linear hook (#6046) * [fp8] disable all_to_all_fp8 in intranode (#6045) * enhance all_to_all_fp8 with internode comm control * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable some fp8 ops due to performance issue * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [release] update version (#6041) * [release] update version * [devops] update comp test * [devops] update comp test debug * [devops] debug comp test * [devops] debug comp test * [devops] debug comp test * [devops] debug comp test * [devops] debug comp test * [Feature] Split cross-entropy computation in SP (#5959) * halfway * fix cross-PP-stage position id length diff bug * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * adapt chatglm, command-R, qwen * debug * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * add comments * q1 index only once * remove events to simplify stream sync * simplify forward/backward logic * 2d ring forward passed * 2d ring backward passed * fixes * fix ring attn loss * 2D ring backward + llama passed * merge * update logger * fix typo * rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * remove typos * fixes * support GPT --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048) * [example] 
pass use_fp8_comm flag to all plugins * [example] add mixtral benchmark * [moe] refine assertion and check * [moe] fix mixtral & add more tests * [moe] consider checking dp * sp group and moe_dp_group * [mixtral] remove gate tp & add more tests * [deepseek] fix tp & sp for deepseek * [mixtral] minor fix * [deepseek] add deepseek benchmark * [fp8] hotfix backward hook (#6053) * [fp8] hotfix backward hook * [fp8] hotfix pipeline loss accumulation * [doc] update sp doc (#6055) * update sp doc * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix the sp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the attn * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * [fp8] fix missing fp8_comm flag in mixtral (#6057) * fix * fix * fix * [fp8] Disable all_gather intranode. 
Disable Redundant all_gather fp8 (#6059) * all_gather only internode, fix pytest * fix cuda arch <89 compile pytest error * fix pytest failure * disable all_gather_into_tensor_flat_fp8 * fix fp8 format * fix pytest * fix conversations * fix chunk tuple to list * [doc] FP8 training and communication document (#6050) * Add FP8 training and communication document * add fp8 docstring for plugins * fix typo * fix typo * fix * fix * [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063) * [ColossalEval] support for vllm (#6056) * support vllm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify vllm and update readme * run pre-commit * remove dupilicated lines and refine code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update param name * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine code * update readme * refine code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [release] update version (#6062) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] fix poc format * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix mem check; * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [feat] moehybrid support zerobubble; * [fix] fix zerobubble pp for shardformer type input; * [fix] fix require_grad & deallocate call; * [fix] fix mem assert; * [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs' * [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch; * [fix] fix zerobubble; support shardformer model type; * [fix] fix test_pipeline_utils ci; * [plugin] hybrid support zero bubble pipeline (#6060) * hybrid support zbv * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix 
weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * hybrid support zbv * fix fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <935724073@qq.com> * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] fix poc format * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [feat] update test; rm comments; * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix mem check; * 
[fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix mem assert; * [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs' * [plugin] hybrid support zero bubble pipeline (#6060) * hybrid support zbv * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * hybrid support zbv * fix fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * fix * [pre-commit.ci] 
auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <935724073@qq.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: HangXu <hangxu0304@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: GuangyaoZhang <xjtu521@qq.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Haze188 <haze188@qq.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: zhurunhua <1281592874@qq.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: Gao, Ruiyuan <905370712@qq.com> Co-authored-by: hxwang <wang1570@e.ntu.edu.sg> Co-authored-by: Michelle <qianranma8@gmail.com> Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Co-authored-by: wangbluo <2538539015@qq.com> Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
parent af6aa9ed06
commit 295dd2d9fe
@@ -1,4 +1,3 @@
-2.1.0-12.1.0
 2.2.2-12.1.0
 2.3.0-12.1.0
 2.4.0-12.4.1
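These version pairs are the CI compatibility matrix (the `.compatibility` file mentioned in the commit message); each entry appears to match a `hpcaitech/pytorch-cuda:<torch>-<cuda>` image tag used by the workflow hunks below. A minimal sketch of how such a list can be consumed, assuming each line maps one-to-one onto an image tag:

```bash
# Sketch only: pull every image listed in the compatibility matrix.
# Assumes each line of .compatibility is a "<torch>-<cuda>" pair that maps
# directly onto a hpcaitech/pytorch-cuda tag, as the workflow hunks below suggest.
while read -r version; do
    docker pull "hpcaitech/pytorch-cuda:${version}"
done < .compatibility
```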
@@ -89,7 +89,7 @@ jobs:
 if: needs.detect.outputs.anyLibraryFileChanged == 'true'
 runs-on: [self-hosted, gpu]
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
 timeout-minutes: 90
 defaults:
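This and the following workflow hunks move the self-hosted GPU jobs from the `hpcaitech/pytorch-cuda:2.1.0-12.1.0` container to `2.2.2-12.1.0`, in line with the PyTorch >= 2.2 requirement further down. To debug such a job outside the runners, the same container can be started by hand; a sketch assuming the image is pullable from Docker Hub (the `/data/scratch` mount only exists on the self-hosted machines and is omitted):

```bash
# Sketch: start the updated CI container locally for debugging.
docker run -it --gpus all --rm \
    -v /dev/shm \
    hpcaitech/pytorch-cuda:2.2.2-12.1.0 \
    bash
```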
@@ -12,7 +12,7 @@ jobs:
 if: github.repository == 'hpcaitech/ColossalAI'
 runs-on: [self-hosted, gpu]
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
 timeout-minutes: 90
 steps:
@@ -64,7 +64,7 @@ jobs:
 
 - name: Install Colossal-AI
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 pip install --no-cache-dir -r requirements/requirements-test.txt
 
 - name: Install tensornvme
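The added `-e` flag switches the CI to an editable install, so tests run against the checked-out sources rather than a copied build. A minimal sketch of reproducing this install step on a local clone (repository URL from the mirror note above; `BUILD_EXT=1` builds the C++/CUDA extensions during install, as in the workflow):

```bash
# Sketch: replicate the CI install step locally.
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
# Editable install with extension build, as the updated workflows now do.
BUILD_EXT=1 pip install -v -e .
# Test-only dependencies, as installed by the test jobs.
pip install --no-cache-dir -r requirements/requirements-test.txt
```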
@@ -58,7 +58,7 @@ jobs:
 
 - name: Install Colossal-AI
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 pip install --no-cache-dir -r requirements/requirements-test.txt
 
 - name: Install tensornvme
@@ -52,7 +52,7 @@ jobs:
 
 - name: Install Colossal-AI
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 pip install --no-cache-dir -r requirements/requirements-test.txt
 
 - name: Install tensornvme
@@ -51,4 +51,4 @@ jobs:
 
 - name: Build
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
@@ -56,7 +56,7 @@ jobs:
 needs: detect-changed-doc
 runs-on: [self-hosted, gpu]
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm
 timeout-minutes: 30
 defaults:
@@ -89,7 +89,7 @@ jobs:
 - name: Install ColossalAI
 run: |
 source activate pytorch
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 
 - name: Test the Doc
 run: |
@@ -12,7 +12,7 @@ jobs:
 name: Test the changed Doc
 runs-on: [self-hosted, gpu]
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm
 timeout-minutes: 60
 steps:
@@ -32,7 +32,7 @@ jobs:
 
 - name: Install ColossalAI
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 
 - name: Install Doc Test Requirements
 run: |
@@ -45,7 +45,7 @@ jobs:
 fail-fast: false
 matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm
 timeout-minutes: 15
 steps:
@@ -53,7 +53,7 @@ jobs:
 uses: actions/checkout@v3
 - name: Install Colossal-AI
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 - name: Test the example
 run: |
 dir=${{ matrix.directory }}
@@ -9,6 +9,7 @@ on:
 paths:
 - "examples/**"
 - "!examples/**.md"
+- ".github/workflows/example_check_on_pr.yml"
 
 jobs:
 # This is for changed example files detect and output a matrix containing all the corresponding directory name.
@@ -89,7 +90,7 @@ jobs:
 fail-fast: false
 matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm
 timeout-minutes: 30
 concurrency:
@@ -107,7 +108,7 @@ jobs:
 
 - name: Install Colossal-AI
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 
 - name: Store Colossal-AI Cache
 run: |
@@ -34,7 +34,7 @@ jobs:
 fail-fast: false
 matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm
 timeout-minutes: 30
 steps:
@@ -43,7 +43,7 @@ jobs:
 
 - name: Install Colossal-AI
 run: |
-BUILD_EXT=1 pip install -v .
+BUILD_EXT=1 pip install -v -e .
 
 - name: Traverse all files
 run: |
@@ -19,7 +19,7 @@ jobs:
 github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
 runs-on: [self-hosted, gpu]
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data --shm-size=10.24gb
 timeout-minutes: 60
 defaults:
@@ -19,7 +19,7 @@ jobs:
 github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
 runs-on: [self-hosted, gpu]
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data
 timeout-minutes: 30
 defaults:
@@ -19,7 +19,7 @@ jobs:
 github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
 runs-on: [self-hosted, gpu]
 container:
-image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
+image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
 volumes:
 - /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa
 - /data/scratch/llama-tiny:/data/scratch/llama-tiny
@ -420,7 +420,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
|
|||
## Installation
|
||||
|
||||
Requirements:
|
||||
- PyTorch >= 2.1
|
||||
- PyTorch >= 2.2
|
||||
- Python >= 3.7
|
||||
- CUDA >= 11.0
|
||||
- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)
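
If you want to sanity-check a local environment against these requirements, a short snippet such as the following can help (an illustrative sketch, not part of the repository):

```python
# Illustrative environment check; adjust the device index if you have several GPUs.
import sys
import torch

print(sys.version_info[:3])        # expect (3, 7) or newer
print(torch.__version__)           # expect 2.2 or newer per the updated requirement above
print(torch.version.cuda)          # expect 11.0 or newer (None means a CPU-only build)
if torch.cuda.is_available():
    print(torch.cuda.get_device_capability(0))  # expect (7, 0) or higher, e.g. (8, 0) on A100
```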
@ -30,7 +30,7 @@ Colossal-LLaMA
|
|||
- [Install](#install)
|
||||
- [0. Pre-requisite](#0-pre-requisite)
|
||||
- [1. Install required packages](#1-install-required-packages)
|
||||
- [2. Install `xentropy`, `layer_norm` and `rotary`](#2-install-xentropy-layer_norm-and-rotary)
|
||||
- [2. Install Apex](#2-install-apex)
|
||||
- [How to run](#how-to-run)
|
||||
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
|
||||
- [2. Init Model Preparation](#2-init-model-preparation)
|
||||
|
@ -297,17 +297,13 @@ Here is details about CLI arguments:
|
|||
#### 1. Install required packages
|
||||
```
|
||||
cd Colossal-LLaMA
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
```
|
||||
#### 2. Install `xentropy`, `layer_norm` and `rotary`
|
||||
|
||||
#### 2. Install Apex
|
||||
```bash
|
||||
git clone git@github.com:Dao-AILab/flash-attention.git
|
||||
# At the root folder
|
||||
cd csrc/xentropy && pip install .
|
||||
# At the root folder
|
||||
cd csrc/layer_norm && pip install .
|
||||
# At the root folder
|
||||
cd csrc/rotary && pip install .
|
||||
git clone git@github.com:NVIDIA/apex.git
|
||||
# Install Apex from source; see the NVIDIA Apex README for the recommended build command.
|
||||
```
|
||||
|
||||
### How to run
|
||||
|
@ -427,25 +423,33 @@ Make sure master node can access all nodes (including itself) by ssh without pas
|
|||
Here are the details about the CLI arguments:
|
||||
* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
|
||||
* Dataset path: `--dataset`. Path to the pre-tokenized dataset.
|
||||
* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
|
||||
* Booster plugin: `--plugin`. `ddp`, `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/). For how the 3d-plugin flags below map onto `HybridParallelPlugin`, see the sketch after this list.
|
||||
* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. A saved checkpoint contains the states for `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights are loaded, without any other states, to support multi-stage training.
|
||||
* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
|
||||
* Checkpoint directory: `--save_dir`. The directory path to save checkpoints and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`.
|
||||
* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
|
||||
* Configuration file: `--config_file`. The path to save the configuration file.
|
||||
* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
|
||||
* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
|
||||
* Batch size: `--batch_size`. Batch size per GPU. The default value is 1. For PP, it refers to number of samples per step.
|
||||
* Learning rate: `--lr`. The default value is 3e-4.
|
||||
* Max length: `--max_length`. Max context length. The default value is 4096.
|
||||
* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
|
||||
* Gradient clipping: `--gradient_clipping`. The default value is 1.0.
|
||||
* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
|
||||
* Warmup steps: `-s`, `--warmup_steps`. The default value is calculated by 0.025 warmup ratio.
|
||||
* Weight decay: `--weight_decay`. The default value is 0.1.
|
||||
* Warmup steps: `--warmup_steps`. The default value is calculated from a warmup ratio of 0.025.
|
||||
* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You should enable this option when training with a large batch size.
|
||||
* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. It helps accelerate training while saving memory, so we recommend always using flash attention.
|
||||
* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze the non-embedding parameters. It can be helpful for aligning embeddings after extending the vocabulary size.
|
||||
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
|
||||
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
|
||||
* Tensor parallelism size: `--tp`. TP size for 3d parallelism. The default value is 1. Used for 3d plugin.
|
||||
* Pipeline parallelism size: `--pp`. PP size for 3d parallelism. The default value is 1. Used for 3d plugin.
|
||||
* Sequence parallelism size: `--sp`. SP size for 3d parallelism. The default value is 1. Used for 3d plugin.
|
||||
* Zero stage: `--zero_stage`. Zero stage for 3d parallelism. The default value is 0. Used for 3d plugin.
|
||||
* Sequence parallelism mode: `--sp_mode`. SP mode, used for 3d plugin. Choose from "split_gather", "ring", "all_to_all".
|
||||
* Switch for sequence parallelism: `--enable_sequence_parallelism`. Whether to enable SP, used for 3d plugin.
|
||||
* Zero CPU offload: `--zero_cpu_offload`. Whether to use offloading, used for 3d plugin.
|
||||
* Micro batch size: `--microbatch_size`. Batch size for each process in PP, used for 3d plugin.
|
||||
* Number of dummy samples: `--num_samples`. Number of random samples used for benchmarking.
|
||||
* Benchmark switch: `--benchmark`. Benchmark performance using a random dataset.
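
The 3d-plugin flags above correspond roughly to the `HybridParallelPlugin` construction in the training script shown later in this diff. Below is a minimal sketch of that mapping; the parallel sizes are illustrative and the script must be launched with `torchrun` or `colossalai run` for the distributed initialization to succeed:

```python
# Sketch only: how the 3d-plugin CLI flags feed HybridParallelPlugin (sizes are illustrative).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # requires a torchrun / colossalai run launcher

plugin = HybridParallelPlugin(
    tp_size=2,                                 # --tp
    pp_size=2,                                 # --pp
    sp_size=1,                                 # --sp
    sequence_parallelism_mode="split_gather",  # --sp_mode
    enable_sequence_parallelism=False,         # --enable_sequence_parallelism
    zero_stage=1,                              # ZeRO stage (--zero_stage in the training script)
    cpu_offload=False,                         # --zero_cpu_offload
    microbatch_size=1,                         # --microbatch_size (PP micro-batch size)
    precision="bf16",                          # --mixed_precision
    max_norm=1.0,                              # gradient clipping value
)
booster = Booster(plugin=plugin)
# The model, optimizer and dataloader are then wrapped via booster.boost(...) as in the script.
```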
##### 4.2 Arguments for Supervised Fine-tuning
|
||||
We add support for gradient accumulation and NEFTune for supervised fine-tuning, so there are two more arguments in addition to those listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(
|
||||
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
|
||||
)
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
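
# --- Usage sketch (not part of the file above): exercising RandomDataset with a plain DataLoader.
# Assumes colossalai is installed, since the tensors are placed on the current accelerator device.
from torch.utils.data import DataLoader

demo_dataset = RandomDataset(num_samples=8, max_length=128, vocab_size=32000)
demo_loader = DataLoader(demo_dataset, batch_size=4, shuffle=True)
demo_batch = next(iter(demo_loader))
print(demo_batch["input_ids"].shape)  # torch.Size([4, 128])
print(demo_batch["labels"].shape)     # labels mirror input_ids for causal-LM loss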
@ -1,352 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
if get_accelerator().name == "cuda":
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
    Decoder attention mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
),
|
||||
q,
|
||||
),
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(
|
||||
output, pattern="... h d -> ... (h d)"
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
    Forward function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
||||
|
||||
elif get_accelerator().name == "npu":
|
||||
import torch_npu
|
||||
|
||||
class NPULlamaAttention(LlamaAttention):
|
||||
use_flash: bool = True
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self._softmax_scale = 1 / math.sqrt(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if not self.use_flash:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
attn_output, *_ = torch_npu.npu_fusion_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
self.num_heads,
|
||||
"BNSD",
|
||||
atten_mask=attention_mask.bool(),
|
||||
scale=self._softmax_scale,
|
||||
padding_mask=None,
|
||||
pre_tockens=65535,
|
||||
next_tockens=0,
|
||||
keep_prob=1.0,
|
||||
inner_precise=0,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum(
|
||||
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
class NPURMSNorm(LlamaRMSNorm):
|
||||
def forward(self, hidden_states):
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.__class__ = NPULlamaAttention
|
||||
module.setup()
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.__class__ = NPURMSNorm
|
|
@ -0,0 +1,36 @@
|
|||
"""
|
||||
Utils for Colossal-LLaMA
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.booster import Plugin
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
|
||||
if plugin is not None:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
|
||||
tensor.div_(plugin.dp_size)
|
||||
else:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f"{numel / B:.2f} B"
|
||||
elif numel >= M:
|
||||
return f"{numel / M:.2f} M"
|
||||
elif numel >= K:
|
||||
return f"{numel / K:.2f} K"
|
||||
else:
|
||||
return f"{numel}"
@ -1,15 +1,15 @@
|
|||
torch==2.1.2
|
||||
huggingface-hub
|
||||
packaging==24.0
|
||||
colossalai==0.3.6
|
||||
colossalai>=0.4.0
|
||||
autoflake==2.2.1
|
||||
black==23.9.1
|
||||
transformers==4.34.1
|
||||
transformers>=4.39.3
|
||||
tensorboard==2.14.0
|
||||
six==1.16.0
|
||||
datasets
|
||||
ninja==1.11.1
|
||||
flash-attn>=2.0.0,<=2.0.5
|
||||
flash-attn
|
||||
tqdm
|
||||
sentencepiece==0.1.99
|
||||
protobuf<=3.20.0
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def fetch_requirements(path):
|
||||
with open(path, "r") as fd:
|
||||
return [r.strip() for r in fd.readlines()]
|
||||
|
||||
|
||||
def fetch_readme():
|
||||
with open("README.md", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def fetch_version():
|
||||
with open("version.txt", "r") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
setup(
|
||||
name="colossal_llama",
|
||||
version=fetch_version(),
|
||||
packages=find_packages(exclude=("*.egg-info",)),
|
||||
description="Continual Pre-training and SFT for LLaMA",
|
||||
long_description=fetch_readme(),
|
||||
long_description_content_type="text/markdown",
|
||||
license="Apache Software License 2.0",
|
||||
url="https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA",
|
||||
install_requires=fetch_requirements("requirements.txt"),
|
||||
python_requires=">=3.7",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Environment :: GPU :: NVIDIA CUDA",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: System :: Distributed Computing",
|
||||
],
|
||||
)
|
|
@ -1,13 +1,20 @@
|
|||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
# NCCL IB environment variables
|
||||
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
|
||||
export NCCL_IB_DISABLE=0
|
||||
export NCCL_SOCKET_IFNAME=eth0
|
||||
export NCCL_IB_GID_INDEX=3
|
||||
export NCCL_IB_TIMEOUT=23
|
||||
export NCCL_IB_RETRY_CNT=7
|
||||
export OMP_NUM_THREADS=8
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
|
||||
PROJECT_NAME=""
|
||||
PARENT_SAVE_DIR=""
|
||||
|
|
|
@ -11,24 +11,24 @@ import resource
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_llama.dataset.dummy_dataset import RandomDataset
|
||||
from colossal_llama.dataset.loader import (
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
)
|
||||
from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama.utils.froze import freeze_non_embeds_parameters
|
||||
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
|
@ -36,109 +36,7 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f"{numel / B:.2f} B"
|
||||
elif numel >= M:
|
||||
return f"{numel / M:.2f} M"
|
||||
elif numel >= K:
|
||||
return f"{numel / K:.2f} K"
|
||||
else:
|
||||
return f"{numel}"
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor = tensor.data
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Address of the pre-trained modeling",
|
||||
)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
||||
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
||||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
choices=["fp16", "bf16"],
|
||||
help="Mixed precision",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument(
|
||||
"--use_grad_checkpoint",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use gradient checkpointing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use flash-attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neft",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use NEFTune",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_non_embeds_params",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Freeze non embeddings parameters",
|
||||
)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=1)
|
||||
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
||||
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
||||
parser.add_argument(
|
||||
"--skip_save_each_epoch",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="skip saving the model checkpoint after each epoch is completed.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
||||
def train(args) -> None:
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
|
@ -147,21 +45,28 @@ def main() -> None:
|
|||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard
|
||||
# Initialize Tensorboard and Save Config
|
||||
# ==============================
|
||||
if coordinator.is_master():
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "gemini":
|
||||
if args.plugin == "ddp":
|
||||
        plugin = TorchDDPPlugin(find_unused_parameters=not args.use_grad_checkpoint)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@ -170,6 +75,8 @@ def main() -> None:
|
|||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
|
@ -189,10 +96,18 @@ def main() -> None:
|
|||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=args.zero,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.microbatch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
@ -210,24 +125,38 @@ def main() -> None:
|
|||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
|
||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
|
||||
coordinator.print_on_master(
|
||||
f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
|
||||
)
|
||||
|
||||
if args.benchmark:
|
||||
coordinator.print_on_master(f"Run benchmark with {args.num_samples} random samples.")
|
||||
dataset = RandomDataset(
|
||||
num_samples=args.num_samples, max_length=args.max_length, vocab_size=tokenizer.vocab_size
|
||||
)
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
seed=42,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
|
||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||
data_collator = DataCollatorForSupervisedDataset(
|
||||
tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
|
||||
)
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.micro_batch_size,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
@ -241,7 +170,19 @@ def main() -> None:
|
|||
else nullcontext()
|
||||
)
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM.from_pretrained(args.pretrained)
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrained,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrained,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
# Freeze part of parameters.
|
||||
if args.freeze_non_embeds_params:
|
||||
freeze_non_embeds_parameters(model=model)
|
||||
|
@ -251,9 +192,6 @@ def main() -> None:
|
|||
if args.use_grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
if args.use_flash_attn:
|
||||
replace_with_flash_attention(model=model)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
@ -342,6 +280,62 @@ def main() -> None:
|
|||
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch=epoch)
|
||||
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
|
||||
data_iter = iter(dataloader)
|
||||
step_bar = tqdm(
|
||||
range(len(dataloader)),
|
||||
desc="Step",
|
||||
disable=not (coordinator._local_rank == coordinator._world_size - 1),
|
||||
)
|
||||
for step in step_bar:
|
||||
outputs = booster.execute_pipeline(
|
||||
data_iter,
|
||||
model,
|
||||
criterion=lambda outputs, inputs: outputs[0],
|
||||
optimizer=optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
if booster.plugin.stage_manager.is_last_stage():
|
||||
global_loss = all_reduce_mean(loss, plugin)
|
||||
if coordinator._local_rank == coordinator._world_size - 1:
|
||||
step_bar.set_postfix({"train/loss": global_loss.item()})
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Save modeling.
|
||||
save_model_condition = args.save_interval > 0 and (step + 1) % args.save_interval == 0
|
||||
|
||||
if not args.skip_save_each_epoch:
|
||||
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
|
||||
|
||||
if save_model_condition and not args.benchmark:
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
accelerator.empty_cache()
|
||||
save_checkpoint(
|
||||
save_dir=args.save_dir,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
else:
|
||||
pbar = tqdm(
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not coordinator.is_master(),
|
||||
|
@ -378,7 +372,6 @@ def main() -> None:
|
|||
pbar.update()
|
||||
|
||||
# Save modeling.
|
||||
|
||||
save_model_condition = (
|
||||
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
|
||||
)
|
||||
|
@ -386,7 +379,7 @@ def main() -> None:
|
|||
if not args.skip_save_each_epoch:
|
||||
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
|
||||
|
||||
if save_model_condition:
|
||||
if save_model_condition and not args.benchmark:
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
|
||||
if args.use_neft:
|
||||
|
@ -402,7 +395,7 @@ def main() -> None:
|
|||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.micro_batch_size,
|
||||
batch_size=args.batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
|
@ -426,6 +419,7 @@ def main() -> None:
|
|||
deactivate_neftune(model, handle)
|
||||
|
||||
# Final save.
|
||||
if not args.benchmark:
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
@ -434,4 +428,105 @@ def main() -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
parser = argparse.ArgumentParser()
|
||||
# Basic training information.
|
||||
parser.add_argument(
|
||||
"--pretrained",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Address of the pre-trained model",
|
||||
)
|
||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint for continuous training.")
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
||||
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
||||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
# Training parameters
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
choices=["fp16", "bf16"],
|
||||
help="Mixed precision",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument(
|
||||
"--use_grad_checkpoint",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use gradient checkpointing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use flash-attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neft",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use NEFTune",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_non_embeds_params",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Freeze non embeddings parameters",
|
||||
)
|
||||
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
||||
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
||||
parser.add_argument(
|
||||
"--skip_save_each_epoch",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Skip saving the model checkpoint after each epoch is completed.",
|
||||
)
|
||||
|
||||
# Additional arguments for 3d plugin.
|
||||
parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
|
||||
parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
|
||||
parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
|
||||
parser.add_argument(
|
||||
"--sp_mode",
|
||||
type=str,
|
||||
default="split_gather",
|
||||
choices=["split_gather", "ring", "all_to_all"],
|
||||
help="SP mode, used for 3d plugin.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_sequence_parallelism",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to enable SP, used for 3d plugin.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
|
||||
)
|
||||
|
||||
# Additional arguments for benchmark.
|
||||
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
|
||||
parser.add_argument(
|
||||
"--benchmark", action="store_true", default=False, help="Benchmark performance using random dataset."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
|
|
@ -1 +1 @@
|
|||
1.0.0
|
||||
1.1.0
|
||||
|
|
|
@ -102,21 +102,10 @@ More details can be found in the latest news.
|
|||
conda create -n colossal-chat python=3.10.9 (>=3.8.7)
|
||||
conda activate colossal-chat
|
||||
|
||||
# Install flash-attention
|
||||
git clone -b v2.0.5 https://github.com/Dao-AILab/flash-attention.git
|
||||
cd $FLASH_ATTENTION_ROOT/
|
||||
pip install .
|
||||
cd $FLASH_ATTENTION_ROOT/csrc/xentropy
|
||||
pip install .
|
||||
cd $FLASH_ATTENTION_ROOT/csrc/layer_norm
|
||||
pip install .
|
||||
cd $FLASH_ATTENTION_ROOT/csrc/rotary
|
||||
pip install .
|
||||
|
||||
# Clone Colossalai
|
||||
# Clone ColossalAI
|
||||
git clone https://github.com/hpcaitech/ColossalAI.git
|
||||
|
||||
# Install ColossalAI
|
||||
# Install ColossalAI, make sure you have torch installed before using BUILD_EXT=1.
|
||||
cd $COLOSSAL_AI_ROOT
|
||||
BUILD_EXT=1 pip install .
|
||||
|
||||
|
|
|
@ -154,7 +154,7 @@ inference_kwargs = {
|
|||
"calculate_loss": True,
|
||||
"all_classes": ["A", "B", "C", "D"],
|
||||
"language": "Chinese",
|
||||
"pretrain": False,
|
||||
"calculate_overall_loss": False,
|
||||
"max_new_tokens": 32
|
||||
}
|
||||
```
|
||||
|
@ -163,7 +163,7 @@ The `inference_kwargs` currently contains 5 fields:
|
|||
- `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated
|
||||
- `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None.
|
||||
- `language` (str, compulsory): The language for the subcategory.
|
||||
- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used to calculate perplexity when you want to evaluate a model with an extended context length.
|
||||
- `calculate_overall_loss` (bool, compulsory): Whether to calculate the overall loss over full sentences if the dataset is a pretrain dataset. It is usually used to calculate perplexity when you want to evaluate a model with an extended context length.
|
||||
- `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference.
|
||||
|
||||
For example, for the MMLU dataset, each subcategory consists of single-choice questions with options A, B, C and D by default, so we can assign the value `["A", "B", "C", "D"]` to the key `all_classes`. For the C-Eval dataset, target answers aren't provided in the test split, so `calculate_loss` should be set to False. However, other datasets such as GAOKAO-Bench contain different question formats and lack some keys or metadata that reveal what type of question (single-choice or multi-choice) it is. Before assigning inference arguments, we first parse the dataset to decide which type of question the subcategory belongs to and set the inference arguments accordingly.
|
||||
|
@ -230,7 +230,7 @@ Example:
|
|||
In this step, you will configure your tokenizer and model arguments to infer on the given datasets.
|
||||
|
||||
A config file consists of two parts.
|
||||
1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
|
||||
1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel`, `ChatGLMModel2` and `vLLMModel`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. `vLLMModel` is for models that can be loaded with vllm offline inference `LLM` class. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
|
||||
2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on the MMLU, CMMLU, AGIEval, GAOKAO-Bench, GSM8K and LongBench datasets and few-shot on the MMLU, CMMLU, AGIEval and GSM8K datasets. If you want to enable few-shot, set `few_shot` to true. You can check all dataset classes in `colossal_eval/dataset/__init__.py`.
|
||||
|
||||
Once you have all config ready, the program will run inference on all the given datasets on all the given models.
|
||||
|
@ -272,7 +272,42 @@ An example config using model class `HuggingFaceCausalLM` and dataset class `CMM
|
|||
}
|
||||
```
|
||||
|
||||
Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong.
|
||||
An example config using model class `vLLMModel` and dataset class `CMMLUDataset` can be:
|
||||
```json
|
||||
{
|
||||
"model": [
|
||||
{
|
||||
"name": "model name",
|
||||
"model_class": "vLLMModel",
|
||||
"parameters": {
|
||||
"path": "path to model",
|
||||
"model_max_length": 2048,
|
||||
"tokenizer_path": "",
|
||||
"tokenizer_kwargs": {
|
||||
"trust_remote_code": true
|
||||
},
|
||||
"model_kwargs": {
|
||||
"trust_remote_code": true
|
||||
},
|
||||
"prompt_template": "plain",
|
||||
"batch_size": 4
|
||||
}
|
||||
}
|
||||
],
|
||||
"dataset": [
|
||||
{
|
||||
"name": "dataset name",
|
||||
"dataset_class": "CMMLUDataset",
|
||||
"debug": false,
|
||||
"few_shot": true,
|
||||
"path": "path to original dataset",
|
||||
"save_path": "path to save converted dataset"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Currently, we support Hugging Face models as well as vLLM models. For Hugging Face models, the `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. For vLLM models, the `tokenizer_kwargs` and `model_kwargs` are loaded together in the `LLM` class. `few_shot` should be set to true if you want to enable few-shot prompting for the dataset. `debug` should be set to true if you want to verify whether your prompt is right or wrong.
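
As a rough sketch of how those fields are consumed for a Hugging Face model loadable with `AutoModelForCausalLM` (the path and kwargs below are placeholders, not a prescribed configuration):

```python
# Sketch: tokenizer_kwargs / model_kwargs are forwarded to the respective from_pretrained calls.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True)     # tokenizer_kwargs
model = AutoModelForCausalLM.from_pretrained("path/to/model", trust_remote_code=True)  # model_kwargs
```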
> For the GSM8K dataset, you can set the additional flags `load_train` or `load_reference` to true in the dataset configuration; during inference, the program will then calculate the loss summation over all tokens for each data sample. During evaluation, you can use the metric `loss_over_all_tokens` to compute the overall loss and use it for data-leakage evaluation.
|
||||
|
||||
|
@ -287,7 +322,7 @@ torchrun --nproc_per_node=4 inference.py \
|
|||
--inference_save_path "path to save inference results"
|
||||
```
|
||||
|
||||
You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size.
|
||||
You should specify the path to the config file in `config`. You can run the script without specifying `load_dataset` if you have already saved the converted dataset, or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate the data parallel size (currently not supported for `vLLMModel`).
|
||||
|
||||
### Evaluation
|
||||
|
||||
|
@ -530,10 +565,6 @@ class CustomizedModel(BaseModel):
|
|||
|
||||
Once you have successfully added your own model, you can specify your model class in your inference config.
|
||||
|
||||
## To do
|
||||
|
||||
- [ ] Add visualization code for evaluation results on public dataset
|
||||
- [ ] Improve the way to label target tokens
|
||||
|
||||
## Citations
|
||||
|
||||
|
|
|
@@ -47,7 +47,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -70,7 +70,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -81,7 +81,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -12,7 +12,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 256,
 }

@@ -15,7 +15,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -36,7 +36,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -72,7 +72,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 256,
 }

@@ -114,7 +114,7 @@ class GSMDataset(BaseDataset):
                 dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)

                 if forward_only:
-                    dataset[split][subject]["inference_kwargs"]["pretrain"] = True
+                    dataset[split][subject]["inference_kwargs"]["calculate_overall_loss"] = True

                 if split == "test" and few_shot:
                     dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data()

@@ -60,7 +60,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -11,7 +11,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -14,7 +14,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": None,
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 1024,
     "turns": 2,
 }

@@ -28,7 +28,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -28,7 +28,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }
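The dataset hunks above all make the same change: the `pretrain` flag in `default_inference_kwargs` is renamed to `calculate_overall_loss`. As a rough sketch of how such a per-dataset config is usually attached to each subject (the helper name `attach_inference_kwargs` and the `forward_only` handling are illustrative, modeled on the GSM hunk above; they are not part of the diff):

```python
import copy

# Illustrative defaults; the keys mirror the hunks above.
default_inference_kwargs = {
    "calculate_loss": True,
    "all_classes": None,
    "language": "Chinese",
    "calculate_overall_loss": False,
    "max_new_tokens": 32,
}


def attach_inference_kwargs(dataset: dict, split: str, subject: str, forward_only: bool = False) -> None:
    """Give every subject its own copy of the defaults so later edits don't leak across subjects."""
    kwargs = copy.deepcopy(default_inference_kwargs)
    if forward_only:
        # Score the whole sequence (per-byte-perplexity style) instead of only the target tokens.
        kwargs["calculate_overall_loss"] = True
    dataset[split][subject]["inference_kwargs"] = kwargs


if __name__ == "__main__":
    data = {"test": {"algebra": {}}}
    attach_inference_kwargs(data, "test", "algebra", forward_only=True)
    print(data["test"]["algebra"]["inference_kwargs"]["calculate_overall_loss"])  # True
```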
@@ -1,5 +1,6 @@
 from .base import BaseModel
 from .chatglm import ChatGLM2Model, ChatGLMModel
 from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
+from .vllm import vLLMModel

-__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
+__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model", "vLLMModel"]
@@ -28,7 +28,7 @@ class ChatGLMModel(HuggingFaceModel):

     @torch.no_grad()
     def get_loss(
-        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
     ) -> List[List[float]]:
         """
         Calculate loss only on target tokens.

@@ -225,7 +225,7 @@ class ChatGLM2Model(ChatGLMModel):

     @torch.no_grad()
     def get_loss(
-        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
     ) -> List[List[float]]:
         """
         Calculate loss only on target tokens.
@@ -105,6 +105,12 @@ class HuggingFaceModel(BaseModel):
             elif hasattr(self.tokenizer, "eod_id"):
                 # Qwen has an eod token "<|endoftext|>".
                 self.tokenizer.pad_token_id = self.tokenizer.eod_id
+            else:
+                self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
+                raise ValueError(
+                    "The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
+                    "Please set pad_token_id manually."
+                )

     def _load_model(
         self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
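The hunk above makes the pad-token fallback explicit: reuse `eos_token`, fall back to a Qwen-style `eod_id`, and otherwise fail loudly. A minimal standalone sketch of the same fallback, assuming a Hugging Face `transformers` tokenizer (the helper name and the `gpt2` checkpoint are only illustrative):

```python
from transformers import AutoTokenizer


def ensure_pad_token(tokenizer) -> None:
    """Reuse eos_token or eod_id as the padding token, otherwise raise."""
    if tokenizer.pad_token_id is not None:
        return
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    elif hasattr(tokenizer, "eod_id"):
        # Some tokenizers (e.g. Qwen) expose an end-of-document id instead of an eos token.
        tokenizer.pad_token_id = tokenizer.eod_id
    else:
        raise ValueError("Tokenizer has no pad_token_id, eos_token, or eod_id; set pad_token_id manually.")


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint without a pad token
    ensure_pad_token(tok)
    print(tok.pad_token_id)
```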
@@ -245,7 +251,7 @@ class HuggingFaceModel(BaseModel):
         return input_ids_list, labels_list, bytes_list

     def _get_input_ids_and_labels(
-        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
     ) -> Tuple[List[torch.LongTensor]]:
         """
         Get input_ids and labels for the given data.

@@ -258,7 +264,7 @@ class HuggingFaceModel(BaseModel):
             Input_ids and labels for the given batch.

         """
-        if pretrain:
+        if calculate_overall_loss:
             batch = []
             # Concatenate prompt and target answers.
             # You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
@@ -342,7 +348,7 @@ class HuggingFaceModel(BaseModel):
         calculate_loss = inference_kwargs["calculate_loss"]
         classes = inference_kwargs["all_classes"]
         language = inference_kwargs["language"]
-        pretrain = inference_kwargs["pretrain"]
+        calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
         max_new_tokens = inference_kwargs["max_new_tokens"]
         few_shot_data = inference_kwargs.get("few_shot_data", None)

@@ -384,12 +390,12 @@ class HuggingFaceModel(BaseModel):
                 self.logger.info("-" * 120)
                 self.logger.info(batch_prompt[0] + batch_target[0][0])

-            if not pretrain:
+            if not calculate_overall_loss:
                 batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)

             if calculate_loss:
                 batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
-                    batch_prompt, batch_target, pretrain
+                    batch_prompt, batch_target, calculate_overall_loss
                 )

             probs = []

@@ -409,7 +415,7 @@ class HuggingFaceModel(BaseModel):
                 ]

             for j in range(len(batch)):
-                if not pretrain:
+                if not calculate_overall_loss:
                     if isinstance(batch[j]["output"], list):
                         batch[j]["output"].append(batch_decodes[j].strip())
                     else:

@@ -496,7 +502,9 @@ class HuggingFaceModel(BaseModel):
         return decoded_sequences, scores

     @torch.no_grad()
-    def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]:
+    def get_loss(
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
+    ) -> List[List[float]]:
         """
         Calculate loss only on target tokens.
@@ -513,13 +521,15 @@ class HuggingFaceModel(BaseModel):
         # We don't need to generate new tokens.
         # Target answer's length is usually << model_max_length, but we still call it in case.
         # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
-        if not pretrain:
+        if not calculate_overall_loss:
             batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

         # Get the number of target answers for different questions
         batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

-        input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
+        input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(
+            batch_prompt, batch_target, calculate_overall_loss
+        )

         # Because of multiple target answers, the final batch size may be greater than self.batch_size.
         # We will generate new batches.
@ -0,0 +1,498 @@
|
|||
import copy
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
from .huggingface import HuggingFaceModel
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
class vLLMModel(HuggingFaceModel):
|
||||
"""
|
||||
Model wrapper around vLLM models.
|
||||
|
||||
Args:
|
||||
path: The path to a vLLM model.
|
||||
model_max_length: The maximum sequence length of the model.
|
||||
tokenizer_path: The path to the tokenizer.
|
||||
tokenizer_kwargs: Keyword arguments for the tokenizer.
|
||||
model_kwargs: Keyword arguments for the model.
|
||||
prompt_template: The model's prompt template.
|
||||
batch_size: Batch size for inference.
|
||||
logger: Logger for the model.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
|
||||
quantization: The method used to quantize the model weights
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
|
||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
|
||||
enforce_eager: Whether to enforce eager execution.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||
disable_custom_all_reduce: See ParallelConfig
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
model_max_length: int = 2048,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
tokenizer_kwargs: Dict = None,
|
||||
model_kwargs: Dict = None,
|
||||
prompt_template: Conversation = None,
|
||||
batch_size: int = 1,
|
||||
logger: DistributedLogger = None,
|
||||
trust_remote_code: bool = False,
|
||||
tensor_parallel_size: int = 1,
|
||||
quantization: Optional[str] = None,
|
||||
gpu_memory_utilization: float = 0.5,
|
||||
swap_space: float = 4,
|
||||
cpu_offload_gb: float = 0,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_context_len_to_capture: Optional[int] = None,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
path=path,
|
||||
model_max_length=model_max_length,
|
||||
prompt_template=prompt_template,
|
||||
batch_size=batch_size,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
self._load_model(
|
||||
path=path,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
tokenizer_path=tokenizer_path if tokenizer_path else None,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
quantization=quantization,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
cpu_offload_gb=cpu_offload_gb,
|
||||
enforce_eager=enforce_eager,
|
||||
max_context_len_to_capture=max_context_len_to_capture,
|
||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
)
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
path: str,
|
||||
model_kwargs: dict,
|
||||
tokenizer_kwargs: dict,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
tensor_parallel_size: int = 1,
|
||||
quantization: Optional[str] = None,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: float = 4,
|
||||
cpu_offload_gb: float = 0,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_context_len_to_capture: Optional[int] = None,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
):
|
||||
"""
|
||||
Load model.
|
||||
|
||||
Args:
|
||||
path: The path to the model.
|
||||
model_kwargs: Keyword arguments for the model.
|
||||
tokenizer_kwargs: Keyword arguments for the tokenizer.
|
||||
tokenizer_path: The path to the tokenizer.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
|
||||
quantization: The method used to quantize the model weights
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
|
||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
|
||||
enforce_eager: Whether to enforce eager execution.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||
disable_custom_all_reduce: See ParallelConfig
|
||||
|
||||
"""
|
||||
if "torch_dtype" in model_kwargs:
|
||||
model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"])
|
||||
model_kwargs.pop("torch_dtype")
|
||||
else:
|
||||
model_kwargs.setdefault("dtype", torch.float16)
|
||||
|
||||
if "trust_remote_code" in model_kwargs:
|
||||
trust_remote_code = model_kwargs["trust_remote_code"]
|
||||
model_kwargs.pop("trust_remote_code")
|
||||
|
||||
if "trust_remote_code" in tokenizer_kwargs:
|
||||
trust_remote_code = tokenizer_kwargs["trust_remote_code"]
|
||||
tokenizer_kwargs.pop("trust_remote_code")
|
||||
|
||||
self.model = LLM(
|
||||
model=path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
quantization=quantization,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
cpu_offload_gb=cpu_offload_gb,
|
||||
enforce_eager=enforce_eager,
|
||||
max_context_len_to_capture=max_context_len_to_capture,
|
||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
**model_kwargs,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
self.tokenizer = self.model.get_tokenizer()
|
||||
|
||||
if self.batch_size > 1:
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.tokenizer.truncation_side = "left"
|
||||
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.")
|
||||
if self.tokenizer.eos_token:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
elif hasattr(self.tokenizer, "eod_id"):
|
||||
# Qwen has an eod token "<|endoftext|>".
|
||||
self.tokenizer.pad_token_id = self.tokenizer.eod_id
|
||||
else:
|
||||
self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
|
||||
raise ValueError(
|
||||
"The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
|
||||
"Please set pad_token_id manually."
|
||||
)
|
||||
|
||||
def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]:
|
||||
"""
|
||||
Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110
|
||||
|
||||
Args:
|
||||
inputs: A batch of input strings.
|
||||
labels: A batch of labels.
|
||||
|
||||
Returns:
|
||||
A list of loss and a list of label length.
|
||||
|
||||
"""
|
||||
batch_size = len(inputs)
|
||||
sampling_kwargs = SamplingParams(logprobs=1)
|
||||
outputs = self.model.generate(inputs, sampling_kwargs)
|
||||
ce_loss = []
|
||||
|
||||
if labels is not None:
|
||||
lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels]
|
||||
else:
|
||||
lens = [1] * batch_size
|
||||
|
||||
for i in range(batch_size):
|
||||
logprobs = outputs[i].outputs[0].logprobs
|
||||
token_ids = outputs[i].outputs[0].token_ids
|
||||
|
||||
logprobs_list = [logprobs[i][token_ids[i]] for i in range(len(logprobs))]
|
||||
logprobs_list = [i.logprob for i in logprobs_list]
|
||||
logprobs_list = np.array(logprobs_list)
|
||||
|
||||
if lens is not None:
|
||||
logprobs_list = logprobs_list[: lens[i]]
|
||||
|
||||
loss = -logprobs_list.sum(axis=-1) / lens[i]
|
||||
ce_loss.append(loss)
|
||||
|
||||
batch_loss = np.array(ce_loss)
|
||||
|
||||
return batch_loss, lens
|
||||
|
||||
def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
|
||||
"""
|
||||
Infer the given data.
|
||||
This function will call self.generate() to get model outputs and use LogitsProcessor param to get specific logits.
|
||||
|
||||
Args:
|
||||
data: The data for inference.
|
||||
inference_kwargs: Arguments for inference.
|
||||
debug: Whether to display generated prompt for debugging.
|
||||
|
||||
Returns:
|
||||
Inference results.
|
||||
|
||||
"""
|
||||
calculate_loss = inference_kwargs["calculate_loss"]
|
||||
classes = inference_kwargs["all_classes"]
|
||||
language = inference_kwargs["language"]
|
||||
calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
|
||||
max_new_tokens = inference_kwargs["max_new_tokens"]
|
||||
few_shot_data = inference_kwargs.get("few_shot_data", None)
|
||||
|
||||
# Some classification questions' options are full texts rather than a single letter such as A, B, C, or D.
|
||||
# If the text length is greater than 1, we won't calculate loss over choices.
|
||||
if classes is not None and any(len(c) > 1 for c in classes):
|
||||
classes = None
|
||||
|
||||
self.choices = classes
|
||||
self.indices_for_choices = None
|
||||
if self.choices:
|
||||
# Get indices for each choice
|
||||
self._get_choices_indices(language)
|
||||
|
||||
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
|
||||
|
||||
bar = tqdm(
|
||||
range(len(data_loader)),
|
||||
desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
answers = []
|
||||
|
||||
for i, batch in enumerate(data_loader):
|
||||
batch_prompt, batch_target = get_batch_prompt(
|
||||
self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
|
||||
)
|
||||
|
||||
if is_rank_0() and debug and i == 0:
|
||||
self.logger.info(
|
||||
f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
|
||||
)
|
||||
self.logger.info("-" * 120)
|
||||
self.logger.info("An example prompt and prompt with target is:")
|
||||
self.logger.info("-" * 120)
|
||||
self.logger.info(batch_prompt[0])
|
||||
self.logger.info("-" * 120)
|
||||
self.logger.info(batch_prompt[0] + batch_target[0][0])
|
||||
|
||||
if not calculate_overall_loss:
|
||||
batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
|
||||
|
||||
if calculate_loss:
|
||||
batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
|
||||
batch_prompt, batch_target, calculate_overall_loss
|
||||
)
|
||||
|
||||
probs = []
|
||||
if self.indices_for_choices:
|
||||
scores = scores.to(torch.float32)
|
||||
# If we have indices_for_choices (single-choice questions only), there will be only one target answer per data sample.
|
||||
# Otherwise this will violate the single-choice setting.
|
||||
|
||||
if calculate_loss:
|
||||
labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]
|
||||
|
||||
loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
|
||||
|
||||
probs = scores.numpy().tolist()
|
||||
probs = [
|
||||
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
|
||||
]
|
||||
|
||||
for j in range(len(batch)):
|
||||
if not calculate_overall_loss:
|
||||
if isinstance(batch[j]["output"], list):
|
||||
batch[j]["output"].append(batch_decodes[j].strip())
|
||||
else:
|
||||
batch[j]["output"] = batch_decodes[j].strip()
|
||||
|
||||
if isinstance(scores, torch.Tensor):
|
||||
batch[j]["logits_over_choices"] = probs[j]
|
||||
|
||||
if calculate_loss:
|
||||
batch[j]["loss_over_choices"] = loss_over_choices[j]
|
||||
|
||||
if calculate_loss:
|
||||
batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
|
||||
|
||||
# loss_sum is specially used for the pretrain dataset to calculate per-byte perplexity.
|
||||
# However, loss (which is per sample loss) suffices for most cases.
|
||||
batch[j]["loss_sum"] = batch_losses[j]
|
||||
batch[j]["token_num"] = batch_target_token_nums[j]
|
||||
|
||||
if batch_bytes_nums:
|
||||
batch[j]["byte_num"] = batch_bytes_nums[j]
|
||||
answers.extend(batch)
|
||||
|
||||
bar.update()
|
||||
|
||||
return answers
|
||||
|
||||
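The inference loop above stores `loss_sum`, `token_num`, and `byte_num` for each sample so that a per-byte perplexity can be derived for pretrain-style scoring, while `loss` keeps the average loss per target answer. A minimal sketch of that arithmetic (the exact aggregation used downstream by the evaluator is an assumption here):

```python
import math
from typing import List


def per_token_loss(loss_sum: float, token_num: int) -> float:
    """Average negative log-likelihood per target token for one target answer."""
    return loss_sum / token_num


def per_byte_perplexity(loss_sums: List[float], byte_nums: List[int]) -> float:
    """Sum the token NLLs over all samples, normalize by total UTF-8 bytes, then exponentiate."""
    total_loss = sum(loss_sums)
    total_bytes = sum(byte_nums)
    return math.exp(total_loss / total_bytes)


if __name__ == "__main__":
    print(per_token_loss(12.5, 10))                        # 1.25
    print(per_byte_perplexity([120.0, 80.0], [600, 400]))  # exp(200 / 1000), roughly 1.22
```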
@torch.no_grad()
|
||||
def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
|
||||
"""Generate results given a list of inputs and get logits of the first new token over choices.
|
||||
|
||||
Args:
|
||||
inputs: A list of strings.
|
||||
max_new_tokens: Max new tokens for generation.
|
||||
kwargs: Key arguments for generation
|
||||
|
||||
Returns:
|
||||
A list of generated strings and logits over choices.
|
||||
|
||||
Note:
|
||||
Currently the function only returns the logits of the first new token.
|
||||
It is used for single choice question.
|
||||
For multiple choices question, please avoid using the loss over choices.
|
||||
You should set argument choices as None in self.inference().
|
||||
|
||||
"""
|
||||
truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)
|
||||
|
||||
generation_kwargs = kwargs.copy()
|
||||
generation_kwargs.update({"max_tokens": max_new_tokens})
|
||||
logits_processor = GetTokenLogitsProcessor(self.indices_for_choices)
|
||||
|
||||
sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs)
|
||||
|
||||
outputs = self.model.generate(truncated_inputs, sampling_kwargs)
|
||||
output_strs = []
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
output_strs.append(generated_text)
|
||||
scores = logits_processor.get_target_logits()
|
||||
|
||||
return output_strs, scores
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loss(
|
||||
self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Calculate loss only on target tokens.
|
||||
|
||||
Args:
|
||||
batch: A batch of prompt without target answer.
|
||||
batch_target: A batch of target answer. Sometimes one question can have multiple target answers.
|
||||
|
||||
Returns:
|
||||
Loss.
|
||||
|
||||
"""
|
||||
|
||||
# We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
|
||||
# We don't need to generate new tokens.
|
||||
# Target answer's length is usually << model_max_length, but we still call it in case.
|
||||
# We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
|
||||
if not calculate_overall_loss:
|
||||
batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
|
||||
|
||||
# Get the number of target answers for different questions
|
||||
batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
|
||||
|
||||
if calculate_overall_loss:
|
||||
batch = []
|
||||
bytes_list = []
|
||||
batch_prompt_pretrain = []
|
||||
for p, b in zip(batch_prompt, batch_target):
|
||||
batch.append(p + b[0])
|
||||
|
||||
for input in batch:
|
||||
# Pretrain data tends to be very long, sometimes much larger than model_max_length, so we only tokenize 1/ratio of the data first to accelerate the tokenization process.
# Once the length of the result is greater than or equal to model_max_length, we stop iterating over ratios and use the result as input_ids and labels.
# After all, the rest of the original string doesn't need to be tokenized in the first place.
|
||||
|
||||
ratio = [16, 8, 4, 2, 1]
|
||||
tokenized = None
|
||||
for r in ratio:
|
||||
tokenized = self.tokenizer(
|
||||
[input[0 : len(input) // r]],
|
||||
truncation=True,
|
||||
max_length=self.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if tokenized.input_ids.size(1) >= self.model_max_length:
|
||||
break
|
||||
|
||||
string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
|
||||
batch_prompt_pretrain.append(string)
|
||||
bytes_list.append(len(string.encode("utf-8")))
|
||||
|
||||
batch_prompt = copy.deepcopy(batch_prompt_pretrain)
|
||||
batch_target = None
|
||||
else:
|
||||
batch_prompt_processed = []
|
||||
batch_target_processed = []
|
||||
for prompt, targets in zip(batch_prompt, batch_target):
|
||||
for target in targets:
|
||||
target_tokenized = self.tokenizer(
|
||||
[target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
|
||||
)
|
||||
max_new_tokens = target_tokenized["input_ids"][0].size(0)
|
||||
prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
|
||||
batch_prompt_processed.append(prompt_with_correct_length)
|
||||
batch_target_processed.append(target)
|
||||
|
||||
batch_prompt = copy.deepcopy(batch_prompt_processed)
|
||||
batch_target = copy.deepcopy(batch_target_processed)
|
||||
bytes_list = None
|
||||
|
||||
# Because of multiple target answers, the final batch size may be greater than self.batch_size.
|
||||
# We will generate new batches.
|
||||
losses = []
|
||||
target_token_nums = []
|
||||
|
||||
losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target)
|
||||
losses.extend(losses_per_batch)
|
||||
target_token_nums.extend(target_token_num_per_batch)
|
||||
|
||||
start_indice = 0
|
||||
losses_per_sample = []
|
||||
|
||||
target_token_nums_per_sample = []
|
||||
bytes_nums_per_sample = []
|
||||
for length in batch_target_nums:
|
||||
losses_per_sample.append(losses[start_indice : start_indice + length])
|
||||
target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
|
||||
|
||||
if bytes_list:
|
||||
bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length])
|
||||
|
||||
start_indice += length
|
||||
|
||||
if bytes_list:
|
||||
return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample
|
||||
|
||||
return losses_per_sample, target_token_nums_per_sample, None
|
||||
|
||||
|
||||
class GetTokenLogitsProcessor:
    """
    LogitsProcessor to get specific logits

    Args:
        indices_for_choices: token indices of required tokens
        target_logits: store all the target logits
    """

    def __init__(
        self,
        indices_for_choices: List[List[int]],
    ):
        self.indices_for_choices = (indices_for_choices,)
        self.target_logits = []

    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        choice_scores = []

        if not input_ids:
            for option_indices in self.indices_for_choices[0]:
                choice_scores.append(logits[option_indices].detach().cpu())

            choice_scores = torch.max(torch.stack(choice_scores), dim=0)[0]
            self.target_logits.append(choice_scores)

        return logits

    def get_target_logits(self) -> torch.Tensor:
        return torch.stack(self.target_logits) if self.target_logits else torch.tensor([])
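A usage sketch for the processor above: the choice-token indices come from the tokenizer, the processor is attached through `SamplingParams(logits_processors=...)`, and the captured logits are read back after generation. The helper `score_choices`, the two-variant index construction, and `max_tokens=1` are illustrative assumptions; in the evaluator the indices come from `HuggingFaceModel._get_choices_indices`.

```python
from typing import List

from vllm import LLM, SamplingParams


def score_choices(llm: LLM, prompts: List[str], choices: List[str]):
    """Capture first-token logits over the answer choices for a batch of prompts."""
    tokenizer = llm.get_tokenizer()

    def first_id(text: str) -> int:
        return tokenizer.encode(text, add_special_tokens=False)[0]

    # One row per tokenization variant, one token id per choice in each row
    # (e.g. "A" and " A" often tokenize differently); this construction is an assumption.
    indices_for_choices = [
        [first_id(c) for c in choices],
        [first_id(" " + c) for c in choices],
    ]

    processor = GetTokenLogitsProcessor(indices_for_choices)  # class defined above
    params = SamplingParams(max_tokens=1, logits_processors=[processor])
    outputs = llm.generate(prompts, params)

    decoded = [out.outputs[0].text for out in outputs]
    return decoded, processor.get_target_logits()  # shape: (num_prompts, num_choices)
```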
|
@ -69,7 +69,7 @@ def rm_and_merge(
|
|||
os.remove(directory)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(len(answers["data"]))
|
||||
|
||||
all_answers[category] = answers
|
||||
|
||||
all_answers_with_dataset_class["inference_results"] = all_answers
|
||||
|
|
|
@ -10,3 +10,4 @@ matplotlib
|
|||
pandas
|
||||
seaborn
|
||||
scikit-learn
|
||||
vllm==0.5.5
|
||||
|
|
|
@ -323,7 +323,9 @@ class GeminiPlugin(DPPluginBase):
|
|||
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
|
||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
|
||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
|
||||
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
|
||||
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
|
||||
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
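The docstring additions above introduce `use_fp8` and `fp8_communication` on `GeminiPlugin`. A minimal usage sketch, assuming a `torchrun`-launched environment; the model and optimizer are placeholders, and the two flags simply mirror the parameters added in this diff:

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()  # assumes torchrun has already set RANK/WORLD_SIZE/etc.

plugin = GeminiPlugin(
    precision="bf16",
    use_fp8=True,            # fp8 mixed-precision compute path
    fp8_communication=True,  # fp8-compressed collectives
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```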
@ -366,7 +368,9 @@ class GeminiPlugin(DPPluginBase):
|
|||
enable_jit_fused: bool = False,
|
||||
enable_sequence_overlap: bool = False,
|
||||
enable_async_reduce: bool = True,
|
||||
use_fp8: bool = False,
|
||||
verbose: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
||||
|
@ -401,6 +405,8 @@ class GeminiPlugin(DPPluginBase):
|
|||
master_weights=master_weights,
|
||||
max_prefetch=max_prefetch,
|
||||
enable_async_reduce=enable_async_reduce,
|
||||
fp8_communication=fp8_communication,
|
||||
use_fp8=use_fp8,
|
||||
)
|
||||
self.zero_optim_config = dict(
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
||||
|
|
|
@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
|||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
ddp_config: dict,
|
||||
custom_policy: Policy,
|
||||
overlap_allgather: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.shard_config = shard_config
|
||||
|
@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
self.use_ddp = use_ddp
|
||||
self.require_grad_sync = True
|
||||
self.overlap_allgather = overlap_allgather
|
||||
self.use_fp8 = use_fp8
|
||||
|
||||
shardformer = ShardFormer(shard_config)
|
||||
if custom_policy is not None:
|
||||
|
@ -112,8 +115,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
module = DDP(module, process_group=dp_group, **ddp_config)
|
||||
|
||||
super().__init__(module)
|
||||
self.op_hooks = []
|
||||
if use_fp8:
|
||||
self.op_hooks.append(FP8Hook())
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
self.op_hooks.append(ZeroOpHook())
|
||||
if use_fp8 or overlap_allgather:
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
|
@ -209,7 +216,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
with self._wait_all_gather():
|
||||
with self._hook_context():
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def unwrap(self):
|
||||
|
@ -222,8 +229,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
for p in self.module.parameters():
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
def _wait_all_gather(self):
|
||||
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
||||
def _hook_context(self):
|
||||
return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
|
||||
|
||||
|
||||
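`_hook_context` above replaces the single `op_hook` with a list so that the fp8 compute hook and the ZeRO all-gather hook can be active at the same time, and it degrades to a no-op when no hooks are configured. A generic sketch of the same pattern using only the standard library (the `PrintHook` class is a stand-in for real parameter-op hooks):

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def apply_hooks(*hooks):
    """Enter every hook's context for the duration of a forward pass; no hooks means no-op."""
    with ExitStack() as stack:
        for hook in hooks:
            stack.enter_context(hook)
        yield


class PrintHook:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        print(f"enter {self.name}")
        return self

    def __exit__(self, *exc):
        print(f"exit {self.name}")


if __name__ == "__main__":
    with apply_hooks(PrintHook("fp8"), PrintHook("zero-allgather")):
        print("forward pass runs with both hooks active")
    with apply_hooks():
        print("no hooks: plain forward pass")
```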
def get_param_info(optim: Optimizer):
|
||||
|
@ -306,6 +313,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
with self.model._hook_context():
|
||||
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
|
@ -529,6 +537,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
None
|
||||
"""
|
||||
# Call the superclass backward method to compute gradients.
|
||||
with self.model._hook_context():
|
||||
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
|
@ -672,6 +681,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
overlap_allgather: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
self.model = model
|
||||
self.param_info = param_info
|
||||
|
@ -701,6 +711,8 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
dp_process_group=dp_process_group,
|
||||
forced_dtype=forced_dtype,
|
||||
overlap_allgather=overlap_allgather,
|
||||
fp8_communication=fp8_communication,
|
||||
backward_context=model._hook_context,
|
||||
)
|
||||
|
||||
def sync_dp_grads(self):
|
||||
|
@ -969,6 +981,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
|
||||
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
|
||||
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
|
||||
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
|
||||
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
|
||||
|
@ -1021,6 +1035,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
inner_ring_size: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
@ -1073,6 +1089,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
self.enable_flash_attention = enable_flash_attention
|
||||
self.enable_jit_fused = enable_jit_fused
|
||||
self.enable_sequence_parallelism = enable_sequence_parallelism
|
||||
self.use_fp8 = use_fp8
|
||||
if dp_outside:
|
||||
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
||||
if sequence_parallelism_mode == "ring_attn":
|
||||
|
@ -1131,6 +1148,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
overlap_p2p=overlap_p2p,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.scheduler = OneForwardOneBackwardSchedule(
|
||||
|
@ -1138,6 +1156,23 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
elif pp_style == "zbv":
|
||||
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||
stage_manager=self.stage_manager,
|
||||
schedule=scheduler_nodes,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
)
|
||||
elif pp_style == "zbv":
|
||||
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||
stage_manager=self.stage_manager,
|
||||
schedule=scheduler_nodes,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
)
|
||||
elif pp_style == "zbv":
|
||||
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||
|
@ -1180,6 +1215,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
fp8_communication=fp8_communication,
|
||||
inner_ring_size=inner_ring_size,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
|
@ -1209,6 +1245,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
partition_grad=(self.zero_stage == 2),
|
||||
forced_dtype=PRECISION_TORCH_TYPE[precision],
|
||||
overlap_allgather=overlap_allgather,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
|
||||
self.max_norm = max_norm
|
||||
|
@ -1271,7 +1308,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
|
||||
self.dp_size == 1 and self.pp_size == 1
|
||||
)
|
||||
|
||||
# sync gradients across DP * SP ranks
|
||||
# Apply Hybrid ZeRO across DP * SP ranks
|
||||
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
|
||||
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
|
||||
|
@ -1289,6 +1326,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if zero_stage == 0:
|
||||
|
@ -1372,7 +1410,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
# so we disable it, performing manual reduction instead.
|
||||
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
with ctx, model._wait_all_gather():
|
||||
with ctx, model._hook_context():
|
||||
outputs = self.scheduler.forward_backward_step(
|
||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
|
|
@ -34,6 +34,7 @@ from colossalai.interface.optimizer import DistributedOptim
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
@ -62,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum):
|
|||
|
||||
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
def __init__(
|
||||
self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
|
||||
self,
|
||||
module: nn.Module,
|
||||
precision: str,
|
||||
overlap_allgather: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
super().__init__(module)
|
||||
self.dtype = None
|
||||
|
@ -75,11 +81,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|||
module = module.to(get_accelerator().get_current_device())
|
||||
self.module = module
|
||||
self.convert_fn = None
|
||||
self.use_fp8 = use_fp8
|
||||
if self.dtype is not None and cast_inputs:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
self.overlap_allgather = overlap_allgather
|
||||
self.op_hooks = []
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
self.op_hooks.append(ZeroOpHook())
|
||||
if use_fp8:
|
||||
self.op_hooks.append(FP8Hook())
|
||||
if overlap_allgather or use_fp8:
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
|
@ -89,14 +100,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
||||
with ctx:
|
||||
with self._hook_context():
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
for p in self.module.parameters():
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
def _hook_context(self):
|
||||
return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
|
||||
|
||||
|
||||
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
|
||||
|
@ -314,6 +327,8 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True.
|
||||
cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False.
|
||||
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
|
||||
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
|
||||
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -337,6 +352,8 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
|
||||
|
@ -360,12 +377,14 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
cpu_offload=cpu_offload,
|
||||
master_weights=master_weights,
|
||||
overlap_allgather=overlap_allgather,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
self.lora_enabled = False
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger()
|
||||
self.cast_inputs = cast_inputs
|
||||
|
||||
self.use_fp8 = use_fp8
|
||||
# set class name with stage, for better error message
|
||||
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
|
||||
|
||||
|
@ -484,6 +503,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
self.precision,
|
||||
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
|
||||
cast_inputs=self.cast_inputs,
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
|
||||
# TODO: Support Galore + ZeRO
|
||||
|
@ -504,7 +524,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
|
||||
optimizer, **zero_optim_kwargs, verbose=self.verbose
|
||||
optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
|
||||
)
|
||||
# inject update_master_params
|
||||
model.update_master_params = MethodType(optimizer.update_master_params, model)
|
||||
|
|
|
@ -65,13 +65,18 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||
forced_dtype: Optional[torch.dtype] = None,
|
||||
overlap_allgather: bool = False,
|
||||
):
|
||||
if dp_process_group is moe_dp_group:
|
||||
pg_param_list = {
|
||||
dp_process_group: list(model.parameters()),
|
||||
}
|
||||
else:
|
||||
pg_param_list = {
|
||||
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
|
||||
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
|
||||
}
|
||||
|
||||
if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
|
||||
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
|
||||
if len(pg_param_list[moe_dp_group]) == 0:
|
||||
raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead")
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
|
@ -166,7 +171,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
|
||||
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
|
||||
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
|
||||
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
|
||||
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -216,6 +223,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
moe_dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
self.logger = get_dist_logger()
|
||||
if overlap_communication or zero_stage == 2:
|
||||
|
@ -339,6 +348,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
||||
else:
|
||||
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
|
||||
self.use_fp8 = use_fp8
|
||||
|
||||
self.shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=self.tp_group,
|
||||
|
@ -357,6 +367,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
initial_scale=initial_scale,
|
||||
|
@ -415,6 +426,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
and self.enable_sequence_parallelism
|
||||
and self.sequence_parallelism_mode == "all_to_all"
|
||||
)
|
||||
|
||||
# sync gradients across DP * SP ranks
|
||||
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
|
||||
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
|
||||
else:
|
||||
dp_group = self.dp_group
|
||||
|
||||
if use_ddp:
|
||||
self.logger.warning(
|
||||
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
|
||||
|
@ -422,17 +440,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
)
|
||||
self.ddp_config["find_unused_parameters"] = True
|
||||
|
||||
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
|
||||
if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
|
||||
raise ValueError(
|
||||
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
|
||||
f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
|
||||
)
|
||||
|
||||
# sync gradients across DP * SP ranks
|
||||
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
|
||||
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
|
||||
else:
|
||||
dp_group = self.dp_group
|
||||
|
||||
model = HybridParallelModule(
|
||||
module=model,
|
||||
precision=self.precision,
|
||||
|
@ -443,6 +455,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
use_ddp=use_ddp,
|
||||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.ep_size > 1:
|
||||
|
@ -473,6 +486,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
tp_process_group=self.tp_group,
|
||||
)
|
||||
else:
|
||||
is_zero = True
|
||||
if self.dp_size <= 1:
|
||||
self.logger.warning(
|
||||
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
|
||||
|
|
|
@ -169,6 +169,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
check_reduction (bool, optional): Whether to check reduction. Defaults to False.
|
||||
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False.
|
||||
static_graph (bool, optional): Whether to use static graph. Defaults to False.
|
||||
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -179,6 +180,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ddp_kwargs = dict(
|
||||
|
@ -189,6 +191,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph,
|
||||
)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def support_no_sync(self) -> bool:
|
||||
return True
|
||||
|
@ -228,6 +231,11 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = OptimizerWrapper(optimizer)
|
||||
|
||||
if self.fp8_communication:
|
||||
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
|
||||
|
||||
model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
|
|
|
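For reference, registering a gradient-compression communication hook on a DDP module follows the same pattern with PyTorch's built-in hooks. The fp8 hook above is ColossalAI-specific, so the sketch below uses the stock fp16 compression hook instead; it assumes an already initialized process group and one CUDA device per rank:

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_with_compressed_grads(module: torch.nn.Module) -> DDP:
    """Wrap a module in DDP and compress gradient buckets to fp16 during all-reduce."""
    assert dist.is_initialized(), "call dist.init_process_group(...) first"
    ddp_module = DDP(module.cuda(), device_ids=[torch.cuda.current_device()])
    # state=None uses the default process group; the hook casts each gradient bucket to fp16,
    # all-reduces it, then casts back, analogous to the fp8 hook registered above.
    ddp_module.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
    return ddp_module
```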
@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
|
||||
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
|
||||
sync_module_states: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.fsdp_kwargs = dict(
|
||||
|
@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
param_init_fn=param_init_fn,
|
||||
sync_module_states=sync_module_states,
|
||||
)
|
||||
self.fp8_communication = fp8_communication
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
else:
|
||||
|
@ -348,6 +350,19 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
# wrap the model with PyTorch FSDP
|
||||
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
|
||||
|
||||
if self.fp8_communication:
|
||||
from colossalai.quantization.utils import patch_fsdp_params_comm_hook
|
||||
|
||||
patch_fsdp_params_comm_hook()
|
||||
|
||||
from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook
|
||||
|
||||
fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
|
||||
|
||||
from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook
|
||||
|
||||
fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
|
||||
|
||||
if optimizer is not None:
|
||||
if len(optimizer.param_groups) > 1:
|
||||
self.logger.warning(
|
||||
|
|
|
@ -220,9 +220,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
if strict:
|
||||
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
|
||||
if len(remain_keys) > 0:
|
||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
||||
)
|
||||
error_msgs = [
|
||||
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys))
|
||||
]
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||
|
|
|
@ -381,9 +381,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
remain_keys = remain_keys.union(set(missing_file_keys))
|
||||
if len(remain_keys) > 0:
|
||||
if strict:
|
||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
||||
)
|
||||
error_msgs = [
|
||||
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
|
||||
]
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||
|
|
|
@ -553,10 +553,10 @@ def load_state_dict_into_model(
|
|||
|
||||
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
|
||||
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
|
||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||
# state_dict
|
||||
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
if strict or len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
module._load_from_state_dict(*args)
|
||||
if load_sub_module:
|
||||
for name, child in module._modules.items():
|
||||
|
@ -570,9 +570,9 @@ def load_state_dict_into_model(
|
|||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in unexpected_keys)
|
||||
)
|
||||
error_msgs = [
|
||||
"Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))
|
||||
]
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||
)
|
||||
|
|
|
@ -116,9 +116,9 @@ class InferCheckpoint_io(GeneralCheckpointIO):
|
|||
remain_keys = remain_keys.union(set(missing_file_keys))
|
||||
if len(remain_keys) > 0:
|
||||
if strict:
|
||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
||||
)
|
||||
error_msgs = [
|
||||
"Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
|
||||
]
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||
|
|
|
@ -9,6 +9,7 @@ import os
|
|||
# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
|
||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
@ -64,6 +65,11 @@ def launch(
|
|||
|
||||
set_seed(seed)
|
||||
|
||||
try:
|
||||
torch._dynamo.config.optimize_ddp = world_size > 1
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if verbose:
|
||||
logger = get_dist_logger()
|
||||
logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
|
||||
|
|
|
@ -119,6 +119,10 @@ class FlashAttentionLoader(KernelLoader):
|
|||
]
|
||||
|
||||
|
||||
class FlashAttentionDaoLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionDaoCudaExtension]
|
||||
|
||||
|
||||
class FlashAttentionWithCustomMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ from torch import Tensor
|
|||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.quantization.fp8 import all_to_all_single_fp8
|
||||
|
||||
MOE_KERNEL = None
|
||||
|
||||
|
||||
|
@ -306,7 +308,7 @@ class EPGradScalerIn(torch.autograd.Function):
|
|||
assert len(grad_outputs) == 1
|
||||
grad = grad_outputs[0]
|
||||
if ctx.ep_size != 1:
|
||||
grad = grad * ctx.ep_size
|
||||
grad.mul_(ctx.ep_size)
|
||||
return grad, None
|
||||
|
||||
|
||||
|
@ -326,7 +328,7 @@ class EPGradScalerOut(torch.autograd.Function):
|
|||
assert len(grad_outputs) == 1
|
||||
grad = grad_outputs[0]
|
||||
if ctx.ep_size != 1:
|
||||
grad = grad / ctx.ep_size
|
||||
grad.div_(ctx.ep_size)
|
||||
return grad, None
|
||||
|
||||
|
||||
|
@ -380,6 +382,7 @@ def _all_to_all(
|
|||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
|
@ -392,6 +395,11 @@ def _all_to_all(
|
|||
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
inputs = inputs.contiguous()
|
||||
outputs = outputs.contiguous()
|
||||
if fp8_communication:
|
||||
handle = all_to_all_single_fp8(
|
||||
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
|
||||
)
|
||||
else:
|
||||
handle = dist.all_to_all_single(
|
||||
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
|
||||
)
|
||||
|
@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
|
|||
output_split_sizes=None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
|
@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
|
|||
ctx.input_split_sizes = input_split_sizes
|
||||
ctx.output_split_sizes = output_split_sizes
|
||||
ctx.group = group
|
||||
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
return _all_to_all(
|
||||
inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs):
|
||||
|
@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
|
|||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
|
@ -435,8 +447,6 @@ def all_to_all_uneven(
|
|||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
assert (
|
||||
inputs.requires_grad
|
||||
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
|
||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
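For orientation, a hedged usage sketch (not part of the patch) of the new fp8_communication flag on all_to_all_uneven; the split sizes, process-group setup, and module path are illustrative assumptions.

# Sketch: uneven all-to-all with optional FP8-compressed payloads (world size 2 assumed)
import torch
from colossalai.moe._operation import all_to_all_uneven  # module path assumed

inputs = torch.randn(6, 32, device="cuda", dtype=torch.float16, requires_grad=True)
outputs = all_to_all_uneven(
    inputs,
    input_split_sizes=[2, 4],   # rows this rank sends to ranks 0 and 1
    output_split_sizes=[3, 3],  # rows this rank receives from ranks 0 and 1
    group=None,                 # default process group
    fp8_communication=True,     # cast the payload to fp8 only for the communication
)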
|
||||
|
|
|
@ -11,6 +11,7 @@ from colossalai.accelerator import get_accelerator
|
|||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||
|
@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__(stage_manager)
|
||||
assert (
|
||||
|
@ -56,6 +58,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -191,8 +195,12 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
|
@ -210,10 +218,14 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
|
@ -224,6 +236,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
is_send = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_first_stage()
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
|
||||
output_tensor,
|
||||
is_send,
|
||||
|
@ -237,6 +251,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def send_backward_recv_backward(
|
||||
|
@ -246,6 +262,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
is_send = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_last_stage()
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
is_send,
|
||||
|
@ -258,6 +276,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
self.send_grad_metadata = not self.enable_metadata_cache and is_send
|
||||
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def forward_step(
|
||||
|
@ -298,7 +318,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if self.stage_manager.is_last_stage():
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
accum_loss.add_(loss.data)
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
|
@ -378,6 +398,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait until current input is received
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
|
||||
if not last_batch:
|
||||
|
@ -440,6 +462,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for input
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
@ -466,6 +490,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for input.
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
|
@ -510,6 +536,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
input_obj, fwd_wait_handles = send_forward_recv_forward()
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
if self.fp8_communication and output_obj_grad is not None:
|
||||
cast_from_fp8_pipeline(output_obj_grad)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
|
||||
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
|
||||
|
@ -531,6 +559,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
if self.fp8_communication and output_obj_grad is not None:
|
||||
cast_from_fp8_pipeline(output_obj_grad)
|
||||
# backward local grads
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
if not last_batch:
|
||||
|
|
|
@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
|
|||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ._utils import (
|
||||
|
@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
"""1F1B pipeline schedule.
|
||||
|
||||
|
@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, next_rank: int = None) -> Any:
|
||||
|
@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor_grad)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
|
@ -157,9 +165,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
|
||||
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
@ -169,8 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
|
||||
|
||||
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
|
@ -183,6 +200,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_last_stage():
|
||||
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
|
||||
send_first = None
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
|
@ -192,6 +211,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
cast_from_fp8_pipeline(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
@ -206,6 +228,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_first_stage():
|
||||
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
|
||||
send_first = None # must not fallback
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
input_tensor, _ = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
|
@ -215,6 +239,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor)
|
||||
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
@ -246,7 +273,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
loss = criterion(output_obj, micro_batch) / self.num_microbatches
|
||||
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
accum_loss.add_(loss.data)
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map_hf(detach, output_obj))
|
||||
return loss
|
||||
|
|
|
@ -0,0 +1,842 @@
|
|||
import os
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from packaging.version import Version
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
|
||||
SCALE_BYTES = 4
|
||||
try:
|
||||
cuda_arch = int("".join(str(i) for i in torch.cuda.get_device_capability()))
|
||||
except:
|
||||
cuda_arch = 0
|
||||
|
||||
|
||||
class Handle:
|
||||
def __init__(self, handles=[], remain_ops=None) -> None:
|
||||
self.handles = handles
|
||||
self.remain_ops = remain_ops
|
||||
|
||||
def wait(self):
|
||||
for handle in self.handles:
|
||||
handle.wait()
|
||||
if self.remain_ops:
|
||||
self.remain_ops()
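A minimal sketch (my own illustration, not part of the patch) of how this Handle wrapper is meant to be consumed: an async FP8 collective returns a Handle whose wait() first waits on the underlying communication work and then runs the deferred dequantization callback. It uses the all_gather_fp8 helper defined later in this file.

# Hypothetical usage sketch; assumes the helpers in this file are importable from colossalai.quantization.fp8
import torch
from colossalai.quantization.fp8 import all_gather_fp8

def overlapped_all_gather(chunk: torch.Tensor, world_size: int):
    outputs = [torch.empty_like(chunk) for _ in range(world_size)]
    handle = all_gather_fp8(outputs, chunk, async_op=True)  # Handle, or a dist work object on the intranode fallback
    # ... independent compute here, overlapped with communication ...
    if handle is not None:
        handle.wait()  # waits on the collectives, then casts the fp8 buffers back
    return outputs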
|
||||
|
||||
|
||||
def process_group_is_intranode(pg):
|
||||
if pg is None:
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
pg = _get_default_group()
|
||||
|
||||
local_world_size = None
|
||||
for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
|
||||
if var in os.environ:
|
||||
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
||||
if local_world_size is None:
|
||||
local_world_size = torch.cuda.device_count()
|
||||
|
||||
group_ranks = dist.get_process_group_ranks(pg)
|
||||
group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
|
||||
return min(group_ranks_node_ids) == max(group_ranks_node_ids)
|
||||
|
||||
|
||||
def cast_to_fp8(
|
||||
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
|
||||
Args:
|
||||
inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor.
|
||||
scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling
|
||||
is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied.
|
||||
fp8_format: e4m3 or e5m2
|
||||
|
||||
Returns:
|
||||
Tuples: A tuple (fp8_tensor, scale)
|
||||
"""
|
||||
|
||||
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
raise TypeError("Only float16, bfloat16, and float32 are allowed.")
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
fp8_max = torch.finfo(fp8_type).max
|
||||
|
||||
if inp.numel() == 0:
|
||||
return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
|
||||
else:
|
||||
if per_channel_scale:
|
||||
per_channel_max = inp.abs().max(dim=-1).values.float()
|
||||
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
|
||||
scale = fp8_max / per_channel_max[:, None]
|
||||
scale_inv = per_channel_max / fp8_max
|
||||
else:
|
||||
per_tensor_max = inp.abs().max().float()
|
||||
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||
scale = fp8_max / per_tensor_max
|
||||
scale_inv = 1.0 / scale
|
||||
|
||||
if out is not None:
|
||||
ret = torch.mul(scale, inp.float(), out=out)
|
||||
else:
|
||||
ret = (scale * inp.float()).to(fp8_type)
|
||||
return ret, torch.unsqueeze(scale_inv, dim=0)
|
||||
|
||||
|
||||
def cast_from_fp8(
|
||||
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
|
||||
scale: scaling factor returned by cast_to_fp8 function.
|
||||
ret_type: the datatype of the returned tensor.
|
||||
Returns:
|
||||
torch.Tensor
|
||||
"""
|
||||
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
|
||||
|
||||
if per_channel_scale:
|
||||
if out is not None:
|
||||
return torch.mul(scale_inv[:, None], inp.float(), out=out)
|
||||
else:
|
||||
ret = scale_inv[:, None] * inp.float()
|
||||
else:
|
||||
if out is not None:
|
||||
return torch.mul(scale_inv, inp.float(), out=out)
|
||||
else:
|
||||
ret = scale_inv * inp.float()
|
||||
return ret.to(ret_type)
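As an illustrative round trip (my own sketch, assuming the two helpers above are importable from colossalai.quantization.fp8): cast_to_fp8 returns the fp8 payload plus the inverse scale, and cast_from_fp8 needs both to reconstruct an approximation of the original tensor.

# Round-trip sketch using the helpers defined above
import torch
from colossalai.quantization.fp8 import cast_to_fp8, cast_from_fp8

x = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")              # per-tensor scaling
x_rec = cast_from_fp8(x_fp8, scale_inv, ret_type=torch.bfloat16)  # dequantize
print((x - x_rec).abs().max())                                    # quantization error: small but non-zero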
|
||||
|
||||
|
||||
def _all_reduce_fp8(
|
||||
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_reduce using fp8.
|
||||
It works like dist.all_reduce but during communication the data is cast to fp8 format.
|
||||
|
||||
Args:
|
||||
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||
fp8_format: e4m3 or e5m2
|
||||
op: ReduceOp.SUM or ReduceOp.AVG
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
input_type = tensor.dtype
|
||||
input_shape = tensor.shape
|
||||
input_device = tensor.device
|
||||
input_size = tensor.numel()
|
||||
flat_padded_x = tensor.flatten()
|
||||
|
||||
assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"
|
||||
|
||||
if flat_padded_x.size(0) % world_size != 0:
|
||||
pad_size = world_size - flat_padded_x.size(0) % world_size
|
||||
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
||||
output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
|
||||
dist.all_to_all(output_chunks, input_chunks, group=group)
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
|
||||
for scale, out in zip(scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
|
||||
if op == ReduceOp.AVG:
|
||||
summed_out.div_(world_size)
|
||||
|
||||
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
||||
gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
||||
|
||||
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
|
||||
gather_tensor_handle = dist.all_gather(
|
||||
tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
|
||||
)
|
||||
|
||||
def cat_op():
|
||||
for i in range(world_size):
|
||||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
out = torch.cat(tensor_list, dim=0)
|
||||
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
|
||||
|
||||
if async_op:
|
||||
return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
|
||||
else:
|
||||
cat_op()
|
||||
|
||||
|
||||
def all_reduce_fp8(
|
||||
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
# fall back to default op due to performance issue
|
||||
return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
||||
def _all_to_all_single_fp8(
|
||||
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_reduce using fp8.
|
||||
It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
|
||||
Args:
|
||||
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||
fp8_format: e4m3 or e5m2
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
input_type = input.dtype
|
||||
input_shape = input.shape
|
||||
input_device = input.device
|
||||
input = input.flatten()
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
if input_split_sizes is not None:
|
||||
input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
|
||||
input_chunks = list(torch.split(inp, input_split_sizes))
|
||||
else:
|
||||
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
||||
|
||||
if output_split_sizes is not None:
|
||||
output_chunks = [
|
||||
torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
|
||||
for i in range(world_size)
|
||||
]
|
||||
else:
|
||||
if dist.get_rank() == world_size - 1:
|
||||
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
|
||||
else:
|
||||
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
|
||||
|
||||
chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
||||
|
||||
def cast_op():
|
||||
cast_output_chunk = [
|
||||
cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
|
||||
]
|
||||
|
||||
tensor_out = torch.cat(cast_output_chunk, dim=0)
|
||||
outputs_shape = list(input_shape)
|
||||
if output_split_sizes is not None:
|
||||
outputs_shape[0] = sum(output_split_sizes)
|
||||
else:
|
||||
outputs_shape = input_shape
|
||||
output.data = tensor_out.view(outputs_shape).to(input_type)
|
||||
|
||||
if async_op:
|
||||
return Handle([chunk_handle, scale_hanle], cast_op)
|
||||
else:
|
||||
cast_op()
|
||||
|
||||
|
||||
def all_to_all_single_fp8(
|
||||
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is wrapper for _all_to_all_single_fp8.
|
||||
"""
|
||||
if process_group_is_intranode(group):
|
||||
return dist.all_to_all_single(
|
||||
output,
|
||||
input,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=async_op,
|
||||
)
|
||||
else:
|
||||
return _all_to_all_single_fp8(
|
||||
output,
|
||||
input,
|
||||
fp8_format=fp8_format,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=async_op,
|
||||
)
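A usage sketch (mine, not from the patch): the wrapper is a drop-in replacement for dist.all_to_all_single, so only the even-split case is shown; a multi-node process group is assumed so the fp8 path is actually taken.

# Drop-in usage sketch for all_to_all_single_fp8 (assumes dist is already initialized)
import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_to_all_single_fp8

world_size = dist.get_world_size()
inp = torch.randn(world_size * 16, 32, dtype=torch.float16, device="cuda")
out = torch.empty_like(inp)
all_to_all_single_fp8(out, inp, fp8_format="e5m2")  # falls back to dist.all_to_all_single inside a single node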
|
||||
|
||||
|
||||
def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
"""
|
||||
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
||||
The activations tensor is indexed by 'hidden_states' in the inp dict.
|
||||
After FP8 casting, the resulting fp8 tensor is viewed as float16 or bfloat16, so the communicated payload is half the original size.
|
||||
Metadata such as fp8_scale is saved into inp dict for communication.
|
||||
"""
|
||||
if inp is None:
|
||||
return
|
||||
# In pipeline parallelism, when inp is a bare torch.Tensor it holds only a single element, so FP8 casting is skipped.
|
||||
if type(inp) == torch.Tensor:
|
||||
return
|
||||
|
||||
assert "hidden_states" in inp, "required by pipeline parallelism."
|
||||
assert (
|
||||
inp["hidden_states"].size(-1) % 2 == 0
|
||||
), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
|
||||
inp_tensor = inp["hidden_states"]
|
||||
inp_dtype = inp_tensor.dtype
|
||||
|
||||
min_val, max_val = inp_tensor.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs())
|
||||
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
if amax > finfo.max:
|
||||
fp8_type = torch.float8_e5m2
|
||||
fp8_view_type = torch.float16
|
||||
else:
|
||||
fp8_type = torch.float8_e4m3fn
|
||||
fp8_view_type = torch.bfloat16
|
||||
|
||||
finfo = torch.finfo(fp8_type)
|
||||
scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
|
||||
q_tensor = inp_tensor.data.float() * scale
|
||||
# Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
|
||||
# inp_tensor needs to be a float datatype to avoid error during gradient placement.
|
||||
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
|
||||
|
||||
inp["fp8_scale"] = scale.float().reciprocal()
|
||||
inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)
|
||||
|
||||
|
||||
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
||||
"""
|
||||
Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
|
||||
del_metadata = False is useful when this function is called before p2p communication.
|
||||
"""
|
||||
if inp is None:
|
||||
return
|
||||
if type(inp) == torch.Tensor:
|
||||
return
|
||||
|
||||
assert "hidden_states" in inp, "required by pipeline parallelism."
|
||||
inp_tensor = inp["hidden_states"]
|
||||
scale = inp["fp8_scale"]
|
||||
|
||||
fp8_view_type = inp_tensor.dtype
|
||||
if fp8_view_type == torch.float16:
|
||||
fp8_type = torch.float8_e5m2
|
||||
elif fp8_view_type == torch.bfloat16:
|
||||
fp8_type = torch.float8_e4m3fn
|
||||
else:
|
||||
raise TypeError("Only float16, bfloat16 are implemented.")
|
||||
|
||||
inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale
|
||||
|
||||
if del_metadata:
|
||||
del inp["fp8_scale"]
|
||||
del inp["dtype"]
|
||||
|
||||
|
||||
def _reduce_scatter_fp8(
|
||||
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is an in-place operation for compressed reduce_scatter using fp8.
|
||||
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
|
||||
|
||||
Args:
|
||||
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||
fp8_format: e4m3 or e5m2
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
input_type = output.dtype
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
scale_list = []
|
||||
cast_input_list = []
|
||||
output_chunks = []
|
||||
output_scale_list = []
|
||||
for input in input_list:
|
||||
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
||||
scale_list.append(scale)
|
||||
ret = ret.view(torch.uint8)
|
||||
cast_input_list.append(ret)
|
||||
output_chunks.append(torch.empty_like(ret))
|
||||
output_scale_list.append(torch.empty_like(scale))
|
||||
chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
|
||||
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
|
||||
|
||||
def cast_op():
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
for scale, out in zip(output_scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
output.data = summed_out
|
||||
|
||||
if async_op:
|
||||
return Handle([chunk_handle, scale_handle], cast_op)
|
||||
else:
|
||||
cast_op()
|
||||
|
||||
|
||||
def reduce_scatter_fp8(
|
||||
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
# fall back to default op due to performance issue
|
||||
return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)
|
||||
|
||||
|
||||
def fp8_compress_ddp_grad_comm_hook_async(
|
||||
process_group: dist.ProcessGroup,
|
||||
bucket: dist.GradBucket,
|
||||
fp8_format: str = "e5m2",
|
||||
) -> torch.futures.Future[torch.Tensor]:
|
||||
"""
|
||||
Compress gradients by casting the ``GradBucket`` tensor to FP8 and dividing by the process group size.

This DDP communication hook implements a simple gradient compression approach that casts the ``GradBucket`` tensor
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then divides it
by the process group size.
Once the compressed gradient tensors are allreduced, the chained callback ``decompress`` casts them back
to the input data type (such as ``float32``).
|
||||
|
||||
Example::
|
||||
>>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
|
||||
"""
|
||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||
|
||||
input_tensor = bucket.buffer()
|
||||
world_size = dist.get_world_size()
|
||||
input_type = input_tensor.dtype
|
||||
input_device = input_tensor.device
|
||||
flat_padded_x = input_tensor.flatten()
|
||||
|
||||
if flat_padded_x.size(0) % world_size != 0:
|
||||
pad_size = world_size - flat_padded_x.size(0) % world_size
|
||||
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
output_chunks_single = torch.empty_like(inp)
|
||||
split_sizes = [inp.numel() // world_size for _ in range(world_size)]
|
||||
fut0 = dist.all_to_all_single(
|
||||
output_chunks_single,
|
||||
inp,
|
||||
output_split_sizes=split_sizes,
|
||||
input_split_sizes=split_sizes,
|
||||
group=group_to_use,
|
||||
async_op=True,
|
||||
).get_future()
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
fut1 = dist.all_gather_into_tensor(
|
||||
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
all_to_all_fut = torch.futures.collect_all([fut0, fut1])
|
||||
|
||||
def sum_and_allgather(fut):
|
||||
output_chunks_single = fut.value()[0].wait()[0]
|
||||
scale_list_single = fut.value()[1].wait()[0]
|
||||
|
||||
output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
|
||||
scale_list = scale_list_single.chunk(world_size, dim=0)
|
||||
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
for scale, out in zip(scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
summed_out.div_(world_size)
|
||||
|
||||
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
||||
|
||||
tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
|
||||
fut2 = dist.all_gather_into_tensor(
|
||||
tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
fut3 = dist.all_gather_into_tensor(
|
||||
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
fut_combined2 = torch.futures.collect_all([fut2, fut3])
|
||||
return fut_combined2
|
||||
|
||||
def decompress(fut):
|
||||
tensor_list_single = fut.value().wait()[0].value()[0]
|
||||
scale_list_single = fut.value().wait()[1].value()[0]
|
||||
|
||||
tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
|
||||
scale_list = scale_list_single.chunk(world_size, dim=0)
|
||||
|
||||
for i in range(world_size):
|
||||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
out = torch.cat(tensor_list, dim=0)
|
||||
|
||||
input_tensor_size = input_tensor.numel()
|
||||
input_shape = input_tensor.shape
|
||||
out = out[:input_tensor_size]
|
||||
|
||||
input_tensor.copy_(out.view(input_shape).to(input_type))
|
||||
return input_tensor
|
||||
|
||||
return all_to_all_fut.then(sum_and_allgather).then(decompress)
|
||||
|
||||
|
||||
def fp8_compress_ddp_grad_comm_hook_sync(
|
||||
process_group: dist.ProcessGroup,
|
||||
bucket: dist.GradBucket,
|
||||
fp8_format="e5m2",
|
||||
) -> torch.futures.Future[torch.Tensor]:
|
||||
"""
|
||||
Return a future that wraps the input after it has been all-reduced. The all-reduce communication here is synchronous,
which breaks the overlap between all-reduce communication and backward computation.

This hook should **only** be used for debugging purposes, not for normal gradient synchronization.
For an asynchronous implementation, use fp8_compress_ddp_grad_comm_hook_async instead.
|
||||
|
||||
Example::
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
|
||||
"""
|
||||
|
||||
buffer = bucket.buffer()
|
||||
all_reduce_fp8(buffer, fp8_format=fp8_format)
|
||||
|
||||
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
|
||||
fut.set_result(bucket.buffer())
|
||||
|
||||
return fut
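A minimal registration sketch (my own example; the model and setup are hypothetical), following the doctest hints in the two hooks above.

# Registering the FP8 DDP gradient-compression hooks (sketch; assumes a standard DDP setup)
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.quantization.fp8 import (
    fp8_compress_ddp_grad_comm_hook_async,
    fp8_compress_ddp_grad_comm_hook_sync,
)

model = torch.nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model)
ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)   # overlap-friendly path
# ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)  # debug-only, synchronous path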
|
||||
|
||||
|
||||
def fp8_compress_fsdp_grad_comm_hook(
|
||||
state: object,
|
||||
unsharded_gradient_flattened: torch.Tensor,
|
||||
sharded_gradient: torch.Tensor,
|
||||
group=None,
|
||||
fp8_format="e5m2",
|
||||
) -> None:
|
||||
"""
|
||||
This communication hook implements a simple gradient compression approach that casts the unsharded_gradient_flattened tensor
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then performs the reduce-scatter logic
using all_to_all and all_gather among the process group.
|
||||
|
||||
Example::
|
||||
>>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
|
||||
"""
|
||||
grad = unsharded_gradient_flattened
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
input_type = grad.dtype
|
||||
input_device = grad.device
|
||||
world_size = dist.get_world_size(group=group)
|
||||
|
||||
grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
|
||||
uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
|
||||
dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
|
||||
buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
|
||||
sharded_gradient.zero_()
|
||||
for tensor, scale in zip(buffer_list, scale_list):
|
||||
sharded_gradient += cast_from_fp8(tensor, scale, input_type)
|
||||
|
||||
|
||||
def fp8_compress_fsdp_params_comm_hook(
|
||||
state: object,
|
||||
padded_unsharded_flat_param: torch.Tensor,
|
||||
sharded_flat_param: torch.Tensor,
|
||||
group=None,
|
||||
fp8_format="e5m2",
|
||||
) -> None:
|
||||
"""
|
||||
This hook is pending the official support for parameters communication hook in FSDP, e.g. register_params_comm_hook.
|
||||
|
||||
Example::
|
||||
>>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
|
||||
"""
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
fp8_max = torch.finfo(fp8_type).max
|
||||
inp = sharded_flat_param
|
||||
out = padded_unsharded_flat_param
|
||||
|
||||
per_tensor_max = inp.abs().max().float()
|
||||
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||
dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)
|
||||
|
||||
scale = fp8_max / per_tensor_max
|
||||
fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)
|
||||
|
||||
fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
|
||||
dist.all_gather_into_tensor(
|
||||
fp8_out,
|
||||
fp8_sharded_flat_param,
|
||||
group=group,
|
||||
)
|
||||
padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))
|
||||
|
||||
|
||||
def split_chunk_by_channel(
|
||||
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
|
||||
):
|
||||
offset = chunk.numel() * rank
|
||||
end = offset + chunk.numel()
|
||||
break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
|
||||
if len(break_points) == 0 or break_points[0] > offset:
|
||||
break_points.insert(0, offset)
|
||||
if break_points[-1] < end:
|
||||
break_points.append(end)
|
||||
sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
|
||||
return chunk.split(sizes)
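A worked example (mine) of what split_chunk_by_channel returns: with channel_size=4, num_channels=4, and two ranks each holding a flat chunk of 8 elements, rank 1's chunk covers global offsets 8..16 and splits on the channel boundaries at 8 and 12.

# Worked example for split_chunk_by_channel (sketch; import path assumed)
import torch
from colossalai.quantization.fp8 import split_chunk_by_channel

chunk = torch.arange(8, 16, dtype=torch.float32)  # rank 1's flat shard of a (4, 4) tensor
parts = split_chunk_by_channel(chunk, channel_size=4, num_channels=4, rank=1, world_size=2)
print([p.tolist() for p in parts])                # [[8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]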
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
||||
def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
||||
world_size = dist.get_world_size(group)
|
||||
input_type = input_list[0].dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
scale_list = []
|
||||
tensor_list = []
|
||||
|
||||
for i in range(world_size):
|
||||
input_tensor = input_list[i]
|
||||
ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
|
||||
scale_list.append(scale)
|
||||
ret = ret.view(torch.uint8)
|
||||
tensor_list.append(ret)
|
||||
|
||||
output_scale_list = [torch.empty_like(x) for x in scale_list]
|
||||
output_tensor_list = [torch.empty_like(x) for x in tensor_list]
|
||||
tensor_hanle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
|
||||
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
|
||||
|
||||
def cast_op():
|
||||
for i in range(world_size):
|
||||
scale = output_scale_list[i]
|
||||
tensor = output_tensor_list[i]
|
||||
tensor = tensor.view(fp8_type)
|
||||
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
|
||||
|
||||
if async_op:
|
||||
return Handle([tensor_hanle, scale_handle], cast_op)
|
||||
else:
|
||||
cast_op()
|
||||
|
||||
|
||||
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
||||
if process_group_is_intranode(group):
|
||||
return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
|
||||
else:
|
||||
return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
||||
def _all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
input_type = input_.dtype
|
||||
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
|
||||
fp8_type = ret.dtype
|
||||
input_ = ret.view(torch.uint8)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
|
||||
chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
|
||||
scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
||||
|
||||
def cast_op():
|
||||
for i in range(world_size):
|
||||
output = tensor_list[i].view(fp8_type)
|
||||
scale = scale_list[i]
|
||||
output_list[i].copy_(cast_from_fp8(output, scale, input_type))
|
||||
|
||||
if async_op:
|
||||
return Handle([chunk_handle, scale_hanle], cast_op)
|
||||
else:
|
||||
cast_op()
|
||||
|
||||
|
||||
def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
||||
if process_group_is_intranode(group):
|
||||
return dist.all_gather(output_list, input_, group=group, async_op=async_op)
|
||||
else:
|
||||
return _all_gather_fp8(output_list, input_, group=group, fp8_format=fp8_format, async_op=async_op)
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
||||
def all_gather_fp8_lagacy(
|
||||
output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
world_size = dist.get_world_size(group)
|
||||
shape = input_.shape
|
||||
input_type = input_.dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
|
||||
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
|
||||
cur_buffer = combined_buffers[dist.get_rank(group)]
|
||||
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
|
||||
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
|
||||
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
|
||||
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
|
||||
dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
|
||||
for out, buf in zip(output_list, combined_buffers):
|
||||
scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
|
||||
output = buf[SCALE_BYTES:].view(fp8_type)
|
||||
cast_from_fp8(output.view(shape), scale, input_type, out=out)
|
||||
# output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
|
||||
# scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
|
||||
# output = output.float() * scales
|
||||
# for i, out in enumerate(output_list):
|
||||
# out.copy_(output[i].view(shape))
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
||||
def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
||||
world_size = dist.get_world_size(group)
|
||||
rank = dist.get_rank(group)
|
||||
|
||||
send_rank = (rank + 1) % world_size
|
||||
recv_rank = (rank - 1) % world_size
|
||||
|
||||
shape = input_.shape
|
||||
input_type = input_.dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
|
||||
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
|
||||
cur_buffer = combined_buffers[dist.get_rank(group)]
|
||||
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
|
||||
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
|
||||
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
|
||||
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
|
||||
|
||||
def send_recv(idx):
|
||||
send_idx = (rank - idx) % world_size
|
||||
recv_idx = (rank - idx - 1) % world_size
|
||||
ops = dist.batch_isend_irecv(
|
||||
[
|
||||
dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group),
|
||||
dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group),
|
||||
]
|
||||
)
|
||||
return ops
|
||||
|
||||
def cast(idx):
|
||||
cast_idx = (rank - idx - 1) % world_size
|
||||
scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
|
||||
output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
|
||||
cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])
|
||||
|
||||
# warmup
|
||||
ops = send_recv(0)
|
||||
output_list[rank].copy_(input_)
|
||||
for op in ops:
|
||||
op.wait()
|
||||
ops = []
|
||||
|
||||
# 1p-1c
|
||||
for i in range(1, world_size - 1):
|
||||
new_ops = send_recv(i)
|
||||
for op in ops:
|
||||
op.wait()
|
||||
cast(i - 1)
|
||||
ops = new_ops
|
||||
|
||||
# cooldown
|
||||
for op in ops:
|
||||
op.wait()
|
||||
cast(world_size - 2)
|
||||
|
||||
|
||||
class _LinearFp8(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> Any:
|
||||
assert (
|
||||
x.dtype in (torch.bfloat16, torch.float16) and x.dtype == w.dtype
|
||||
), "Only float16 and bfloat16 are allowed."
|
||||
if bias is not None:
|
||||
assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
|
||||
# ensure x and w are row-major
|
||||
x = x.contiguous()
|
||||
w = w.contiguous()
|
||||
ctx.x_shape = x.shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.out_dtype = x.dtype
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
x_fp8, inv_scale_x = cast_to_fp8(x, fp8_format="e4m3")
|
||||
w_fp8, inv_scale_w = cast_to_fp8(w, fp8_format="e4m3")
|
||||
ctx.x_fp8 = x_fp8
|
||||
ctx.w_fp8_t = w_fp8.t()
|
||||
ctx.inv_scale_x = inv_scale_x
|
||||
ctx.inv_scale_w = inv_scale_w
|
||||
out = torch._scaled_mm(
|
||||
x_fp8,
|
||||
ctx.w_fp8_t,
|
||||
bias=bias,
|
||||
out_dtype=ctx.out_dtype,
|
||||
scale_a=inv_scale_x,
|
||||
scale_b=inv_scale_w,
|
||||
use_fast_accum=True,
|
||||
)[0]
|
||||
return out.reshape(*ctx.x_shape[:-1], w.shape[0])
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, out_grad) -> Any:
|
||||
out_grad = out_grad.reshape(-1, out_grad.shape[-1])
|
||||
out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2")
|
||||
x_grad = torch._scaled_mm(
|
||||
out_grad_fp8,
|
||||
ctx.w_fp8_t.contiguous().t(),
|
||||
out_dtype=ctx.out_dtype,
|
||||
scale_a=out_grad_scale,
|
||||
scale_b=ctx.inv_scale_w,
|
||||
use_fast_accum=True,
|
||||
)[0]
|
||||
w_grad = torch._scaled_mm(
|
||||
out_grad_fp8.t().contiguous(),
|
||||
ctx.x_fp8.t().contiguous().t(),
|
||||
out_dtype=ctx.out_dtype,
|
||||
scale_a=out_grad_scale,
|
||||
scale_b=ctx.inv_scale_x,
|
||||
use_fast_accum=True,
|
||||
)[0]
|
||||
bias_grad = None
|
||||
if ctx.has_bias:
|
||||
bias_grad = out_grad.sum(0)
|
||||
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
|
||||
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
|
||||
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out = _linear_fp8(input, weight, bias)
|
||||
return out
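A small usage sketch (mine): linear_fp8 expects fp16/bf16 row-major operands and internally runs torch._scaled_mm on e4m3-quantized inputs; shapes follow the usual F.linear convention.

# Usage sketch for linear_fp8 (requires a GPU with fp8 support, e.g. compute capability >= 8.9)
import torch
from colossalai.quantization.fp8 import linear_fp8

x = torch.randn(16, 128, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = linear_fp8(x, w, b)   # same semantics as F.linear(x, w, b)
y.sum().backward()        # gradients also flow through the fp8 matmuls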
|
|
@@ -0,0 +1,23 @@
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook


class FP8Hook(ColoParamOpHook):
    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass

    def rewrite_op(self, func):
        if func is F.linear:
            return linear_fp8
        return func
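The rewrite behaviour can be checked directly against the class above: only F.linear is remapped, everything else passes through unchanged.

# Quick check of FP8Hook.rewrite_op using the class defined just above
import torch.nn.functional as F

hook = FP8Hook()
assert hook.rewrite_op(F.linear) is linear_fp8  # F.linear is routed to the fp8 matmul
assert hook.rewrite_op(F.gelu) is F.gelu        # other ops are returned untouched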
|
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging import version
|
||||
from torch import Tensor
|
||||
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
|
||||
from torch.distributed.utils import _p_assert
|
||||
|
||||
|
||||
def _all_gather_flat_param(
|
||||
self,
|
||||
padded_unsharded_flat_param: Tensor,
|
||||
) -> Tensor:
|
||||
"""
|
||||
All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
|
||||
|
||||
Then switch to use the all-gathered tensor.
|
||||
"""
|
||||
_p_assert(
|
||||
hasattr(self, "process_group") and hasattr(self, "world_size"),
|
||||
"Expects a process group and world size to have been set via `shard()`",
|
||||
)
|
||||
sharded_flat_param = self.flat_param.data
|
||||
expected_numel = sharded_flat_param.numel() * self.world_size
|
||||
_p_assert(
|
||||
padded_unsharded_flat_param.numel() == expected_numel,
|
||||
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
|
||||
)
|
||||
|
||||
pg = self._fake_process_group if self._use_fake_all_gather else self.process_group
|
||||
|
||||
# HACK this should be handled by C10D
|
||||
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
|
||||
tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
|
||||
work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
|
||||
else:
|
||||
if self._comm_hook is None:
|
||||
dist.all_gather_into_tensor(
|
||||
padded_unsharded_flat_param,
|
||||
sharded_flat_param,
|
||||
pg,
|
||||
)
|
||||
else:
|
||||
self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)
|
||||
|
||||
if self._offload_params:
|
||||
# In case of offloading, `flat_param.data` (i.e. sharded param) is
|
||||
# created on the pre-unshard stream. We need to hand it over to the
|
||||
# unshard stream for all-gather
|
||||
_no_dispatch_record_stream(
|
||||
sharded_flat_param,
|
||||
self._device_handle.current_stream(), # unshard_stream
|
||||
)
|
||||
return padded_unsharded_flat_param
|
||||
|
||||
|
||||
def register_params_comm_hook(self, state: object, hook: callable):
|
||||
"""Register a communication hook for FlatParamHandle.
|
||||
|
||||
This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards
|
||||
parameters across multiple workers.
|
||||
|
||||
.. warning ::
|
||||
FSDP communication hook should be registered before running an initial forward pass
|
||||
and only once.
|
||||
|
||||
Args:
|
||||
state (object): Passed to the hook to maintain any state information during the training process.
|
||||
hook (Callable): Callable, which has one of the following signatures:
|
||||
1) ``hook: Callable[torch.Tensor] -> None``:
|
||||
This function takes in a Python tensor, which represents
|
||||
the full, flattened, unsharded gradient with respect to all variables
|
||||
corresponding to the model this FSDP unit is wrapping
|
||||
(that are not wrapped by other FSDP sub-units).
|
||||
It then performs all necessary processing and returns ``None``;
|
||||
2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
|
||||
This function takes in two Python tensors, the first one represents
|
||||
the full, flattened, unsharded gradient with respect to all variables
|
||||
corresponding to the model this FSDP unit is wrapping
|
||||
(that are not wrapped by other FSDP sub-units). The latter
|
||||
represents a pre-sized tensor to store a chunk of a sharded gradient after
|
||||
reduction.
|
||||
In both cases, callable performs all necessary processing and returns ``None``.
|
||||
Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case.
|
||||
Callables with signature 2 are expected to handle gradient communication for sharded cases.
|
||||
|
||||
"""
|
||||
if not self.check_is_root():
|
||||
raise AssertionError("register_comm_hook can only be called on a root instance.")
|
||||
|
||||
# if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
|
||||
# raise AssertionError(
|
||||
# f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
|
||||
# )
|
||||
if self._handle._comm_hook is not None:
|
||||
raise AssertionError("A communication hook is already registered")
|
||||
if not callable(hook):
|
||||
raise ValueError(f"The communication hook must be callable but got {hook}")
|
||||
self._handle._comm_hook = hook
|
||||
self._handle._comm_hook_state = state
|
||||
|
||||
|
||||
def patch_fsdp_params_comm_hook():
|
||||
if version.parse(torch.__version__) >= version.parse("2.2.0"):
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp._flat_param import FlatParamHandle
|
||||
|
||||
FlatParamHandle._comm_hook = None
|
||||
FlatParamHandle._comm_hook_state = None
|
||||
FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
|
||||
FSDP.register_params_comm_hook = register_params_comm_hook
|
||||
else:
|
||||
raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.")
|
|
@ -16,6 +16,14 @@ try:
|
|||
except ImportError:
|
||||
_grad_accum_fusion_available = False
|
||||
|
||||
from colossalai.quantization.fp8 import (
|
||||
all_gather_fp8,
|
||||
all_reduce_fp8,
|
||||
all_to_all_fp8,
|
||||
all_to_all_single_fp8,
|
||||
reduce_scatter_fp8,
|
||||
)
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
r"""Layernorm
|
||||
|
@ -61,11 +69,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
|
||||
|
@ -78,6 +87,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
def backward(ctx, grad_output):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
|
||||
weight = weight.view(weight.shape)
|
||||
|
@ -92,7 +102,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.view(-1, total_input.shape[-1])
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
if ctx.async_grad_allreduce and fp8_communication:
|
||||
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
|
||||
elif ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
|
@ -101,10 +113,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
grad_weight = total_input.t().matmul(grad_output)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
if ctx.async_grad_allreduce and not fp8_communication:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
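# --- FP8 format convention used in this diff (sketch, not the library code) --
# Forward-path collectives above pass fp8_format="e4m3" (more mantissa bits),
# backward-path collectives pass "e5m2" (more dynamic range for gradients).
# The round-trip below shows the per-tensor scaling such a cast typically needs;
# the actual helpers live in colossalai.quantization.fp8 and may differ.
# Requires torch >= 2.1 for the float8 dtypes.
import torch

def _fp8_roundtrip(t: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Scale so the largest magnitude maps onto the FP8 dtype's maximum value.
    scale = t.abs().max().clamp(min=1e-12) / torch.finfo(fp8_dtype).max
    return (t / scale).to(fp8_dtype).to(t.dtype) * scale

_activation = torch.randn(4, 8)           # forward tensors -> e4m3
_gradient = torch.randn(4, 8) * 1e3       # backward tensors -> e5m2 (wider range)
_act_q = _fp8_roundtrip(_activation, torch.float8_e4m3fn)
_grad_q = _fp8_roundtrip(_gradient, torch.float8_e5m2)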
||||
|
||||
class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
|
@ -113,11 +125,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.fp8_communication = fp8_communication
|
||||
if bias is not None:
|
||||
output = F.linear(input_, weight, bias)
|
||||
else:
|
||||
|
@ -129,6 +142,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
def backward(ctx, grad_output):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||
if use_bias:
|
||||
|
@ -144,10 +158,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
|
||||
if ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
if fp8_communication:
|
||||
all_reduce_fp8(grad_input, group=ctx.process_group)
|
||||
else:
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
_ = torch.zeros(1, device=grad_input.device)
|
||||
|
||||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
|
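# --- Note on the overlap above (illustration only) ---------------------------
# The asynchronous all-reduce is only guaranteed to be enqueued before the
# weight-gradient matmul when the device exposes a single hardware work queue,
# which is why the comment above relies on CUDA_DEVICE_MAX_CONNECTIONS=1
# (set in shardformer.py). Setting it manually, e.g. in a launcher script:
import os

os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")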
@ -165,10 +180,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
if ctx.async_grad_allreduce and not fp8_communication:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
|
||||
|
@ -236,17 +251,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim):
|
||||
def forward(ctx, input_, process_group, dim, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
return _gather(input_, dim, process_group)
|
||||
return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
|
||||
fp8_communication = ctx.fp8_communication
|
||||
# do reduce-scatter
|
||||
new_shape = list(grad_output.shape)
|
||||
assert (
|
||||
|
@ -257,9 +273,13 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
|
||||
]
|
||||
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
|
||||
|
||||
if fp8_communication:
|
||||
reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2")
|
||||
else:
|
||||
dist.reduce_scatter(output, grad_list, group=process_group)
|
||||
|
||||
return output, None, None
|
||||
return output, None, None, None
|
||||
|
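# --- Local shape check for the reduce-scatter above (illustration only) ------
# Chunking grad_output along `dim` and shrinking that dim by the world size
# reproduces exactly the buffers that dist.reduce_scatter / reduce_scatter_fp8
# expect, which can be verified without a process group.
import torch

_world_size, _dim = 4, 1
_grad_output = torch.randn(2, 32, 16)
_new_shape = list(_grad_output.shape)
_new_shape[_dim] //= _world_size
_grad_list = [c.contiguous() for c in torch.chunk(_grad_output, _world_size, dim=_dim)]
assert all(c.shape == torch.Size(_new_shape) for c in _grad_list)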
||||
|
||||
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
|
@ -550,9 +570,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim):
|
||||
def forward(ctx, input_, process_group, dim, fp8_communication=False):
|
||||
ctx.dim = dim
|
||||
ctx.process_group = process_group
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
# do reduce-scatter
|
||||
new_shape = list(input_.shape)
|
||||
|
@ -562,6 +583,9 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
|
||||
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
|
||||
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||
if fp8_communication:
|
||||
reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
|
||||
else:
|
||||
dist.reduce_scatter(output, input_list, group=process_group)
|
||||
|
||||
return output
|
||||
|
@ -570,8 +594,9 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
def backward(ctx, grad_output):
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
return _gather(grad_output, dim, process_group), None, None
|
||||
return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None
|
||||
|
||||
|
||||
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
|
@ -586,13 +611,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
|
||||
def forward(
|
||||
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
|
||||
):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||
ctx.dim = dim
|
||||
ctx.overlap = overlap
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
if ring is True:
|
||||
input_to_gather = {}
|
||||
|
@ -609,7 +637,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
)
|
||||
|
||||
else:
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
output = torch.matmul(input_parallel, weight)
|
||||
|
||||
|
@ -624,6 +652,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
overlap = ctx.overlap
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
|
||||
weight = weight.view(weight.shape)
|
||||
|
@ -631,7 +660,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
bias = bias.view(bias.shape)
|
||||
|
||||
if not overlap:
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
|
||||
|
||||
total_input = input_parallel
|
||||
grad_input = grad_output.matmul(weight.T)
|
||||
|
@ -691,7 +720,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
# wait until reduce-scatter finished
|
||||
reducescatter_handle.wait()
|
||||
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
|
@ -706,17 +735,25 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None):
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.fp8_communication = fp8_communication
|
||||
return _split(input_, dim, process_group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale is not None:
|
||||
grad_output = grad_output * ctx.grad_scale
|
||||
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
|
||||
|
||||
return (
|
||||
_gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class _ReduceForward(torch.autograd.Function):
|
||||
|
@ -730,15 +767,15 @@ class _ReduceForward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, grad_scale=None):
|
||||
def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
|
||||
ctx.grad_scale = grad_scale
|
||||
return _reduce(input_, process_group)
|
||||
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale is not None:
|
||||
grad_output = grad_output * ctx.grad_scale
|
||||
return grad_output, None, None
|
||||
return grad_output, None, None, None
|
||||
|
||||
|
||||
class _ReduceBackward(torch.autograd.Function):
|
||||
|
@ -751,13 +788,15 @@ class _ReduceBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group):
|
||||
def forward(ctx, input_, process_group, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.fp8_communication = fp8_communication
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _reduce(grad_output, ctx.process_group), None
|
||||
fp8_communication = ctx.fp8_communication
|
||||
return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
|
@ -770,17 +809,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None):
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
return _gather(input_, dim, process_group)
|
||||
|
||||
return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale is not None:
|
||||
grad_output = grad_output * ctx.grad_scale
|
||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
|
||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None
|
||||
|
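# --- Why every backward above gained an extra trailing None (illustration) ---
# autograd requires backward() to return one gradient per forward() argument.
# Adding the non-tensor `fp8_communication` flag therefore adds one more None,
# as the toy Function below demonstrates.
import torch

class _ScaleWithFlag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, fp8_communication=False):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient slot each for: x, factor, fp8_communication
        return grad_output * ctx.factor, None, None

_x = torch.randn(3, requires_grad=True)
_ScaleWithFlag.apply(_x, 2.0, True).sum().backward()
assert torch.allclose(_x.grad, torch.full_like(_x, 2.0))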
||||
|
||||
class _AllToAll(torch.autograd.Function):
|
||||
|
@ -794,26 +834,67 @@ class _AllToAll(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = input_.shape
|
||||
|
||||
# using all_to_all_single when batch size is 1
|
||||
if bsz == 1:
|
||||
return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
return _all_to_all_single(
|
||||
input_,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e4m3",
|
||||
)
|
||||
else:
|
||||
return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
return _all_to_all(
|
||||
input_,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e4m3",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
def backward(ctx, grad_output):
|
||||
process_group = ctx.process_group
|
||||
scatter_dim = ctx.gather_dim
|
||||
gather_dim = ctx.scatter_dim
|
||||
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
|
||||
return (return_grad, None, None, None)
|
||||
fp8_communication = ctx.fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = grad_output.shape
|
||||
|
||||
if bsz == 1:
|
||||
return_grad = _all_to_all_single(
|
||||
grad_output,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e5m2",
|
||||
)
|
||||
else:
|
||||
return_grad = _all_to_all(
|
||||
grad_output,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e5m2",
|
||||
)
|
||||
|
||||
return (return_grad, None, None, None, None)
|
||||
|
||||
|
||||
class HookParameter(torch.autograd.Function):
|
||||
|
@ -839,10 +920,13 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
|
|||
return HookParameter.apply(input, weight, bias)
|
||||
|
||||
|
||||
def _reduce(input_, process_group):
|
||||
def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
|
||||
# skip if only one rank involved
|
||||
if dist.get_world_size(process_group) == 1:
|
||||
return input_
|
||||
else:
|
||||
if fp8_communication:
|
||||
all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
|
||||
else:
|
||||
dist.all_reduce(input_, group=process_group)
|
||||
return input_
|
||||
|
@ -868,18 +952,19 @@ def _split(input_, dim=-1, process_group=None):
|
|||
return output
|
||||
|
||||
|
||||
def _gather(input_, dim=-1, process_group=None):
|
||||
def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
|
||||
# skip if only one rank involved
|
||||
world_size = dist.get_world_size(process_group)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
input_ = input_.contiguous()
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(tensor_list, input_, group=process_group)
|
||||
if fp8_communication:
|
||||
all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
|
||||
else:
|
||||
dist.all_gather(tensor_list, input_, group=process_group)
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
return output
|
||||
|
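# --- Minimal FP8 all-gather sketch (not the library implementation) ----------
# all_gather_fp8 above is assumed to behave roughly like this: quantize the
# local shard with a per-tensor scale, gather the byte payloads and the scales,
# then de-quantize on every rank. Needs an initialized process group and
# torch >= 2.1 for the float8 dtypes.
import torch
import torch.distributed as dist

def _naive_all_gather_fp8(shard: torch.Tensor, group=None, fp8_dtype=torch.float8_e5m2):
    world_size = dist.get_world_size(group)
    scale = (shard.abs().max().clamp(min=1e-12) / torch.finfo(fp8_dtype).max).reshape(1)
    payload = (shard / scale).to(fp8_dtype).view(torch.uint8)   # ship raw bytes

    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload, group=group)
    dist.all_gather(scales, scale, group=group)
    # de-quantize each gathered shard back to the original dtype
    return [p.view(fp8_dtype).to(shard.dtype) * s for p, s in zip(payloads, scales)]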
@ -909,14 +994,19 @@ def _reduce_scatter(input_, dim=1, process_group=None):
|
|||
return output
|
||||
|
||||
|
||||
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
|
||||
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
|
||||
if fp8_communication:
|
||||
all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format)
|
||||
else:
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
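# --- Local shape check for the all-to-all above (illustration only) ----------
# With scatter_dim=2 and gather_dim=1 (the Ulysses-style call used by
# all_to_all_comm), a [B, S/ws, H] shard becomes [B, S, H/ws] once every rank
# has exchanged its chunks; since all chunks share one shape, the bookkeeping
# can be verified without a process group.
import torch

_ws, _scatter_dim, _gather_dim = 4, 2, 1
_x = torch.randn(2, 32, 64)                                  # [B, S/ws, H]
_chunks = [t.contiguous() for t in torch.tensor_split(_x, _ws, _scatter_dim)]
_out_shape = torch.cat(_chunks, dim=_gather_dim).shape       # -> [B, S, H/ws]
assert tuple(_out_shape) == (2, 32 * _ws, 64 // _ws)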
||||
|
||||
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
||||
def _all_to_all_single(
|
||||
input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
|
||||
):
|
||||
inp_shape = list(input_.shape)
|
||||
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
|
||||
if scatter_dim < 2:
|
||||
|
@ -929,6 +1019,10 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
|||
)
|
||||
|
||||
output = torch.empty_like(input_t)
|
||||
if fp8_communication:
|
||||
all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format)
|
||||
else:
|
||||
|
||||
dist.all_to_all_single(output, input_t, group=group)
|
||||
|
||||
if scatter_dim < 2:
|
||||
|
@ -943,12 +1037,16 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
|||
).contiguous()
|
||||
|
||||
|
||||
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
return MatmulWithAsyncCommunication.apply(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
||||
)
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
return LinearWithAsyncCommunication.apply(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
||||
)
|
||||
|
||||
|
||||
def linear_gather_forward_reducescatter_backward(
|
||||
|
@ -959,12 +1057,12 @@ def linear_gather_forward_reducescatter_backward(
|
|||
)
|
||||
|
||||
|
||||
def gather_forward_reducescatter_backward(input_, process_group, dim):
|
||||
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
|
||||
def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
|
||||
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
|
||||
|
||||
|
||||
def reducescatter_forward_gather_backward(input_, process_group, dim):
|
||||
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
|
||||
def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
|
||||
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
|
||||
|
||||
|
||||
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
|
||||
|
@ -972,38 +1070,46 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
|
|||
|
||||
|
||||
def matmul_gather_forward_reducescatter_backward(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
|
||||
):
|
||||
return _MatmulWithGatherForwardReduceScatterBackward.apply(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
|
||||
)
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
|
||||
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
|
||||
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
|
||||
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
|
||||
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def reduce_forward(input_, process_group, grad_scale=None):
|
||||
return _ReduceForward.apply(input_, process_group, grad_scale)
|
||||
def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
|
||||
return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def reduce_backward(input_, process_group):
|
||||
return _ReduceBackward.apply(input_, process_group)
|
||||
def reduce_backward(input_, process_group, fp8_communication=False):
|
||||
return _ReduceBackward.apply(input_, process_group, fp8_communication)
|
||||
|
||||
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
|
||||
|
||||
|
||||
def gather_sp_output(hidden_states, sp_group, sp_mode):
|
||||
def gather_sp_output(hidden_states, shard_config, sp_dim=1):
|
||||
"""
|
||||
Gather the output of the last layer for cross entropy computation
|
||||
"""
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
fp8_comm = shard_config.fp8_communication
|
||||
if dist.get_world_size(sp_group) == 1:
|
||||
return hidden_states
|
||||
|
||||
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
|
||||
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm
|
||||
)
|
||||
return hidden_states
|
||||
|
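# --- Worked check of the grad_scale above (assumption: ZeRO averages gradients
# over the DP * SP group, per the comment in gather_sp_output) ----------------
# With dp=2 and sp=4 the optimizer divides gradients by 8; for an SP mode that
# does not share ranks with TP, only the DP average (by 2) is wanted, so the
# backward of the gather rescales by sp_size.
_dp, _sp = 2, 4
_zero_divisor = _dp * _sp
_grad_scale = _sp            # grad_scale passed to gather_forward_split_backward
assert _zero_divisor / _grad_scale == _dp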
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from einops import rearrange
|
||||
|
||||
from colossalai.kernel.kernel_loader import (
|
||||
FlashAttentionDaoLoader,
|
||||
FlashAttentionForFloatAndCustomMaskLoader,
|
||||
FlashAttentionLoader,
|
||||
FlashAttentionWithCustomMaskLoader,
|
||||
|
@ -17,6 +18,8 @@ from colossalai.logging import get_dist_logger
|
|||
|
||||
from .utils import RingComm, get_half_index, split_varlen_zigzag
|
||||
|
||||
MEMORY_BOUND = 10 * 1e9
|
||||
|
||||
__all__ = [
|
||||
"AttnMaskType",
|
||||
"ColoAttention",
|
||||
|
@ -77,6 +80,7 @@ def get_pad_info(
|
|||
|
||||
class ColoAttention:
|
||||
_kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
|
||||
_flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
|
||||
|
||||
@staticmethod
|
||||
def _init_kernels_dispatch():
|
||||
|
@ -102,9 +106,11 @@ class ColoAttention:
|
|||
torch.bfloat16: half_dispatch_map,
|
||||
torch.float32: float_dispatch_map,
|
||||
}
|
||||
if ColoAttention._flash_kernel_dispatch is None:
|
||||
ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
|
||||
|
||||
@staticmethod
|
||||
def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
|
||||
def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
|
||||
ColoAttention._init_kernels_dispatch()
|
||||
if (
|
||||
dtype not in ColoAttention._kernel_dispatch_map
|
||||
|
@ -113,11 +119,19 @@ class ColoAttention:
|
|||
raise ValueError(
|
||||
"FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
|
||||
)
|
||||
|
||||
if size >= MEMORY_BOUND:
|
||||
if isinstance(ColoAttention._flash_kernel_dispatch, KernelLoader):
|
||||
ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
|
||||
# lazy load
|
||||
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
|
||||
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
|
||||
mask_type
|
||||
].load()
|
||||
|
||||
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
|
||||
return ColoAttention._flash_kernel_dispatch
|
||||
else:
|
||||
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
||||
|
||||
@staticmethod
|
||||
|
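# --- Quick arithmetic behind MEMORY_BOUND (illustration only) ----------------
# The dispatch above materializes a dense (s_q x s_kv) mask only when it fits
# under the 10 GB bound; e.g. a 65k x 65k fp16 mask needs ~8.6 GB and still
# passes, while anything larger falls back to the mask-free FlashAttention path.
import torch

_s_q = _s_kv = 65_536
_element_size = torch.tensor([], dtype=torch.float16).element_size()   # 2 bytes
_memory_size = _s_q * _s_kv * _element_size                            # ~8.6e9 bytes
assert _memory_size < 10 * 1e9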
@ -154,6 +168,8 @@ class ColoAttention:
|
|||
return {}
|
||||
assert len(shape_4d) == 4 and shape_4d[1] == 1
|
||||
b, _, s_q, s_kv = shape_4d
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
memory_size = s_q * s_kv * element_size
|
||||
outputs = {}
|
||||
if (q_padding_mask is None or q_padding_mask.bool().all()) and (
|
||||
kv_padding_mask is None or kv_padding_mask.bool().all()
|
||||
|
@ -161,10 +177,13 @@ class ColoAttention:
|
|||
# no padding
|
||||
assert is_causal
|
||||
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
||||
if memory_size < MEMORY_BOUND:
|
||||
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
||||
if s_q != 1:
|
||||
attention_mask = attention_mask.tril(diagonal=0)
|
||||
attention_mask.tril_(diagonal=0)
|
||||
attention_mask = attention_mask.expand(b, s_q, s_kv)
|
||||
else:
|
||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||
else:
|
||||
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
||||
if kv_padding_mask is None:
|
||||
|
@ -177,7 +196,6 @@ class ColoAttention:
|
|||
b,
|
||||
s_kv,
|
||||
), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
|
||||
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
||||
outputs.update(
|
||||
{
|
||||
"cu_seqlens_q": cu_seqlens_q,
|
||||
|
@ -190,10 +208,17 @@ class ColoAttention:
|
|||
)
|
||||
if is_causal:
|
||||
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
||||
if memory_size < MEMORY_BOUND:
|
||||
if s_q != 1:
|
||||
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
||||
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
|
||||
else:
|
||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||
else:
|
||||
outputs["attention_mask_type"] = AttnMaskType.PADDED
|
||||
if memory_size < MEMORY_BOUND:
|
||||
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
||||
|
||||
if invert:
|
||||
attention_mask = invert_mask(attention_mask).unsqueeze(1)
|
||||
outputs["attention_mask"] = attention_mask
|
||||
|
@ -278,8 +303,12 @@ class ColoAttention:
|
|||
assert attention_mask_type == AttnMaskType.CUSTOM
|
||||
|
||||
# kernel dispatch
|
||||
b, _, s_q, _ = q.shape
|
||||
b, _, s_kv, _ = v.shape
|
||||
element_size = torch.tensor([], dtype=q.dtype).element_size()
|
||||
memory_size = s_q * s_kv * element_size
|
||||
mask_type = attention_mask_type if attention_mask is not None else None
|
||||
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
|
||||
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
|
||||
is_causal = attention_mask is not None and attention_mask_type in (
|
||||
AttnMaskType.CAUSAL,
|
||||
AttnMaskType.PADDED_CAUSAL,
|
||||
|
@ -433,7 +462,6 @@ class RingAttention(torch.autograd.Function):
|
|||
assert (
|
||||
sp_size % inner_ring_size == 0
|
||||
), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.info(
|
||||
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
|
||||
|
@ -898,6 +926,7 @@ class RingAttention(torch.autograd.Function):
|
|||
|
||||
local_sp_rank = dist.get_rank(sp_group)
|
||||
sp_size = dist.get_world_size(sp_group)
|
||||
|
||||
# Using separate streams (pg) for concurrent kv and dkv comm may
|
||||
# cause NCCL "software caused connection abort" here...
|
||||
local_kv_comm = RingComm(local_kv_group)
|
||||
|
@ -1119,9 +1148,14 @@ class RingAttention(torch.autograd.Function):
|
|||
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
|
||||
|
||||
Returns:
|
||||
inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
|
||||
mask_info: A dictionary of mask info.
|
||||
position_ids: Packed position ids of shape [..., Sq // sp_size].
|
||||
torch.Tensor:
|
||||
Packed input embeddings of shape [B, Sq // sp_size, ...].
|
||||
|
||||
Dict[str, Any]:
|
||||
A dictionary containing mask info.
|
||||
|
||||
torch.Tensor:
|
||||
Packed position ids of shape [..., Sq // sp_size].
|
||||
|
||||
"""
|
||||
_load_varlen_helpers()
|
||||
|
|
|
@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
|
|||
gather_output: bool = True,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
fp8_communication: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
|
|||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.gather_output = gather_output
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
|
@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
|
|||
def forward(self, input_: Tensor) -> Tensor:
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
if self.gather_output:
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return output_parallel
|
||||
|
@ -274,6 +278,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
|
|||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
fp8_communication: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -282,6 +287,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
|
|||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
tensor_parallel_rank = dist.get_rank(group=process_group)
|
||||
|
@ -390,5 +396,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
|
|||
embedding_output = output_parallel.clone()
|
||||
embedding_output[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_forward(embedding_output, self.process_group)
|
||||
output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
|
||||
return output
|
||||
|
|
|
@ -84,6 +84,7 @@ class Linear1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||
|
@ -98,6 +99,7 @@ class Linear1D_Col(ParallelModule):
|
|||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -202,19 +204,25 @@ class Linear1D_Col(ParallelModule):
|
|||
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = gather_forward_reducescatter_backward(
|
||||
input_parallel, self.process_group, self.seq_parallel_dim
|
||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
@ -264,6 +272,7 @@ class Linear1D_Row(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -278,6 +287,7 @@ class Linear1D_Row(ParallelModule):
|
|||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -398,7 +408,9 @@ class Linear1D_Row(ParallelModule):
|
|||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
input_ = split_forward_gather_backward(
|
||||
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
if self.training:
|
||||
|
@ -416,10 +428,13 @@ class Linear1D_Row(ParallelModule):
|
|||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
if self.seq_parallel_mode is None:
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, self.seq_parallel_dim
|
||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output = linear_reducescatter_forward_gather_backward(
|
||||
|
@ -562,6 +577,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
|||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
fp8_communication: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# create weight and bias
|
||||
|
@ -592,6 +608,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
|||
**kwargs,
|
||||
new_num_embeddings=new_out_features,
|
||||
old_num_embeddings=out_features,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
# get the length of valid embeddings
|
||||
tp_rank = dist.get_rank(process_group)
|
||||
|
|
|
@ -153,7 +153,6 @@ def dist_cross_entropy(
|
|||
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
shard_config: ShardConfig,
|
||||
out_features: int,
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
seq_dim: int = 1,
|
||||
|
@ -226,13 +225,13 @@ def dist_cross_entropy(
|
|||
logits,
|
||||
labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=out_features,
|
||||
vocab_size=vocab_size,
|
||||
dtype=dtype,
|
||||
mode="sum",
|
||||
)
|
||||
else:
|
||||
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
|
||||
logits = logits.view(-1, vocab_size)
|
||||
logits = logits.view(-1, logits.size(-1))
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
# Reduce loss instead of gathering logits over seq dim for savings
|
||||
|
|
|
@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
self.n_fused = n_fused
|
||||
self.process_group = process_group
|
||||
self.async_communication = async_communication
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -311,27 +313,50 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
if self.seq_parallel_mode is None:
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_backward(input_, self.process_group)
|
||||
output_parallel = matmul_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, self.async_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
True,
|
||||
1,
|
||||
self.overlap,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
True,
|
||||
1,
|
||||
self.overlap,
|
||||
True,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_backward(input_, self.process_group)
|
||||
output_parallel = matmul_with_async_comm(
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
self.async_communication,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
@ -379,6 +404,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -392,6 +418,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
self.process_group = process_group
|
||||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -514,7 +541,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
input_ = split_forward_gather_backward(
|
||||
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
if self.training:
|
||||
|
@ -533,15 +562,26 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
if self.seq_parallel_mode is None:
|
||||
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reduce_forward(output_parallel, self.process_group)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel,
|
||||
self.process_group,
|
||||
1,
|
||||
self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel,
|
||||
self.process_group,
|
||||
1,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
|
||||
|
||||
if not self.skip_bias_add:
|
||||
if self.bias is not None:
|
||||
|
@ -600,6 +640,7 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
|
@ -611,6 +652,7 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
self.n_fused = n_fused
|
||||
self.process_group = process_group
|
||||
self.async_communication = async_communication
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -740,7 +782,9 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
|
|
@ -309,6 +309,9 @@ def split_batch_zigzag(
|
|||
"""
|
||||
sp_size = dist.get_world_size(sp_group)
|
||||
sp_rank = dist.get_rank(sp_group)
|
||||
if sp_size == 1:
|
||||
return batch
|
||||
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch = [batch]
|
||||
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
|
||||
|
@ -364,6 +367,9 @@ def split_varlen_zigzag(
|
|||
"""
|
||||
sp_size = dist.get_world_size(sp_group)
|
||||
sp_rank = dist.get_rank(sp_group)
|
||||
if sp_size == 1:
|
||||
return batch
|
||||
|
||||
if is_2d:
|
||||
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
|
||||
|
||||
|
|
|
@ -187,11 +187,17 @@ class BertPipelineForwards:
|
|||
if shard_config is not None and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = split_forward_gather_backward(
|
||||
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
encoder_hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
|
||||
|
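# --- How the flag reaches these forwards (hedged sketch) ---------------------
# The modeling code reads shard_config.fp8_communication, so enabling FP8
# communication amounts to building the ShardConfig with that field set. The
# field names below are taken from this diff; the remaining arguments are
# illustrative and are normally filled in by the plugin.
#
# from colossalai.shardformer import ShardConfig
# shard_config = ShardConfig(
#     tensor_parallel_process_group=tp_group,
#     enable_sequence_parallelism=True,
#     sequence_parallelism_mode="split_gather",
#     fp8_communication=True,
# )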
@ -242,7 +248,10 @@ class BertPipelineForwards:
|
|||
if shard_config is not None and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
if output_hidden_states:
|
||||
|
@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
embedding_output = split_forward_gather_backward(
|
||||
embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
embedding_output,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = split_forward_gather_backward(
|
||||
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
encoder_hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
|
@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
|
||||
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
|
||||
sequence_output = gather_forward_split_backward(
|
||||
sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
sequence_output,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
|
|
@ -221,7 +221,10 @@ class BloomPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
@ -264,7 +267,10 @@ class BloomPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
|
@ -359,12 +365,13 @@ class BloomPipelineForwards:
|
|||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
lm_logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.transformer.dtype,
|
||||
)
|
||||
|
||||
|
@ -922,7 +929,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
@ -960,7 +970,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
|
||||
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
@ -1024,8 +1037,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
|
@ -13,10 +12,13 @@ from colossalai.shardformer import ShardConfig
|
|||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
gather_sp_output,
|
||||
is_share_sp_tp,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
|
||||
from ..layer import dist_cross_entropy
|
||||
|
||||
|
||||
def get_flash_core_attention_forward():
|
||||
from .chatglm2_6b.modeling_chatglm import CoreAttention
|
||||
|
@ -138,6 +140,7 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_output_gather: Optional[bool] = True,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
output_hidden_states = (
|
||||
|
@ -180,6 +183,15 @@ class ChatGLMPipelineForwards:
|
|||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Support SP + PP
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
# For generating full position ids (the states will be gathered along the seq dim before attention fwd).
|
||||
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
|
||||
seq_length *= sp_size
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
if position_ids is not None:
|
||||
|
@ -200,12 +212,14 @@ class ChatGLMPipelineForwards:
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
# Keep the input split across all PP stages
|
||||
if stage_manager.is_first_stage():
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if sp_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
process_group=sp_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
|
@ -214,6 +228,7 @@ class ChatGLMPipelineForwards:
|
|||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
)
|
||||
|
||||
for idx in range(start_idx, end_idx):
|
||||
layer = self.encoder._get_layer(idx)
|
||||
if output_hidden_states:
|
||||
|
@ -239,26 +254,19 @@ class ChatGLMPipelineForwards:
|
|||
if use_cache:
|
||||
presents = presents + (kv_cache,)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
if stage_manager.is_last_stage():
|
||||
# final layer_norm
|
||||
if self.encoder.post_layer_norm:
|
||||
hidden_states = self.encoder.final_layernorm(hidden_states)
|
||||
|
||||
# Gather seq-wise in the final output stage
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
|
@ -315,6 +323,7 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
@ -322,17 +331,21 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states = hidden_states[-1:]
|
||||
lm_logits = self.transformer.output_layer(hidden_states)
|
||||
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
lm_logits = lm_logits.to(torch.float32)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
lm_logits = lm_logits.to(hidden_states.dtype)
|
||||
loss = loss.to(hidden_states.dtype)
|
||||
# ChatGLM doesn't have lm_head split
|
||||
enable_tp = shard_config.enable_tensor_parallelism
|
||||
shard_config.enable_tensor_parallelism = False
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
lm_logits,
|
||||
shard_config,
|
||||
self.transformer.output_layer.out_features,
|
||||
lm_logits.dtype,
|
||||
)
|
||||
shard_config.enable_tensor_parallelism = enable_tp
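# --- illustrative sketch (not part of the patch): scoping the temporary flag flip ---
# The enable_tensor_parallelism toggle above could also be expressed as a context manager
# so the flag is restored even if dist_cross_entropy raises. Hypothetical helper:
from contextlib import contextmanager

@contextmanager
def _tp_disabled(cfg):
    prev = cfg.enable_tensor_parallelism
    cfg.enable_tensor_parallelism = False
    try:
        yield
    finally:
        cfg.enable_tensor_parallelism = prev

# usage (hypothetical):
# with _tp_disabled(shard_config):
#     loss = dist_cross_entropy(labels, lm_logits, shard_config, ...)
# --- end sketch ---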
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -361,6 +374,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
|||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
force_sp_output_gather: Optional[bool] = True,
|
||||
):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
|
@ -401,6 +415,12 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
|||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
||||
|
||||
if sp_mode in ["all_to_all"] and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
f"`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
if sp_mode in ["all_to_all"] and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
|
@ -414,6 +434,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
|||
inputs_embeds,
|
||||
dim=0,
|
||||
process_group=sp_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
|
@ -421,6 +442,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
|||
dim=0,
|
||||
process_group=sp_group,
|
||||
grad_scale=1 / sp_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
|
@ -430,20 +452,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
|||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if sp_mode in ["split_gather"]:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=sp_group,
|
||||
grad_scale=sp_size,
|
||||
)
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
|
@ -532,9 +543,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
|
|||
key_layer = key_layer.reshape(sq, bs, -1)
|
||||
value_layer = value_layer.reshape(sq, bs, -1)
|
||||
|
||||
query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
|
||||
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
|
||||
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
|
||||
query_layer = all_to_all_comm(
|
||||
query_layer,
|
||||
sp_group,
|
||||
gather_dim=0,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
key_layer = all_to_all_comm(
|
||||
key_layer,
|
||||
sp_group,
|
||||
gather_dim=0,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
value_layer = all_to_all_comm(
|
||||
value_layer,
|
||||
sp_group,
|
||||
gather_dim=0,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
query_layer = query_layer.view(
|
||||
sq * sp_size,
|
||||
|
@ -610,7 +636,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
|
|||
|
||||
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
||||
if sp_mode == "all_to_all":
|
||||
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
|
||||
context_layer = all_to_all_comm(
|
||||
context_layer,
|
||||
sp_group,
|
||||
gather_dim=2,
|
||||
scatter_dim=0,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
# =================
|
||||
# Output. [sq, b, h]
|
||||
|
|
|
@ -17,14 +17,13 @@ from transformers.models.cohere.modeling_cohere import (
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
from ..layer._operation import gather_sp_output, is_share_sp_tp
|
||||
|
||||
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
|
||||
|
||||
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
|
||||
|
||||
|
@ -52,6 +51,7 @@ class CommandPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -93,10 +93,16 @@ class CommandPipelineForwards:
|
|||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
# NOTE: For generating full position ids
|
||||
# (the states will be gathered along the seq dim before attention fwd).
|
||||
if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage():
|
||||
seq_length *= shard_config.sequence_parallel_size
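# --- illustrative sketch (not part of the patch): why seq_length is rescaled here ---
# On non-first PP stages the hidden states arrive already split along the sequence dim,
# but position/cache ids must cover the full sequence, since the states are gathered
# seq-wise before the attention forward. Worked example with assumed sizes:
full_seq_len, sp_size_example = 4096, 4
local_seq_len = full_seq_len // sp_size_example   # tokens actually held by this rank
restored = local_seq_len * sp_size_example        # value used for position / cache ids
assert restored == full_seq_len
# --- end sketch ---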
|
||||
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
|
||||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
|
@ -136,12 +142,13 @@ class CommandPipelineForwards:
|
|||
)
|
||||
use_cache = False
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
|
@ -149,6 +156,7 @@ class CommandPipelineForwards:
|
|||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
|
@ -206,21 +214,10 @@ class CommandPipelineForwards:
|
|||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
)
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
@ -323,6 +320,7 @@ class CommandPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
|
@ -331,9 +329,10 @@ class CommandPipelineForwards:
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
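# --- illustrative sketch (not part of the patch): the idea behind a vocab-parallel CE ---
# dist_cross_entropy avoids materializing the full softmax on one rank. A single-process
# simulation of the underlying math, with the vocab split into "rank" shards; this is a
# sketch of the principle, not the library implementation, and the helper name is hypothetical.
import torch

def sharded_cross_entropy(logits, labels, num_shards):
    # logits: [N, V], labels: [N]
    shards = logits.chunk(num_shards, dim=-1)
    global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values  # "all-reduce max"
    sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in shards)          # "all-reduce sum"
    target_logit = logits.gather(-1, labels[:, None]).squeeze(-1)                        # owned by one shard
    return (global_max + sum_exp.log() - target_logit).mean()

_logits = torch.randn(8, 32000)
_labels = torch.randint(0, 32000, (8,))
_ref = torch.nn.functional.cross_entropy(_logits, _labels)
assert torch.allclose(sharded_cross_entropy(_logits, _labels, num_shards=4), _ref, atol=1e-4)
# --- end sketch ---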
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -384,9 +383,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
|
|||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
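# --- illustrative sketch (not part of the patch): shapes around the q/k/v all-to-all ---
# Assuming all_to_all_comm here scatters the hidden (head) dim and gathers the sequence dim,
# each rank trades "all heads of my sequence shard" for "my heads of the full sequence".
# Single-process model of that exchange (sp ranks simulated by a Python list):
import torch

bsz_ex, full_len, n_heads_ex, head_dim_ex, sp = 2, 16, 8, 4, 4
inputs = [torch.randn(bsz_ex, full_len // sp, n_heads_ex * head_dim_ex) for _ in range(sp)]
outputs = [
    torch.cat([inputs[src].chunk(sp, dim=-1)[dst] for src in range(sp)], dim=1)
    for dst in range(sp)
]
assert outputs[0].shape == (bsz_ex, full_len, (n_heads_ex // sp) * head_dim_ex)
# --- end sketch ---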
|
||||
|
@ -448,7 +447,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
|
|||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
|
@ -476,6 +477,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
|||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -528,9 +530,13 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
|||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
|
@ -574,10 +580,10 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
|
|||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
# Cases that don't support parallelizing cross entropy computation along sequence
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
@ -662,6 +668,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
@ -669,12 +676,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.model.dtype,
|
||||
)
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
|
@ -24,14 +24,17 @@ from colossalai.moe._operation import (
|
|||
all_to_all_uneven,
|
||||
)
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import all_reduce_fp8
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
linear_with_async_comm,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
|
||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
||||
|
||||
|
||||
|
@ -57,11 +60,17 @@ class AddAuxiliaryLoss(torch.autograd.Function):
|
|||
return grad_output, grad_loss
|
||||
|
||||
|
||||
class EPDeepseekMoE(nn.Module):
|
||||
class EPDeepseekMoE(ParallelModule):
|
||||
def __init__(self):
|
||||
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
|
||||
|
||||
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
|
||||
def setup_process_groups(
|
||||
self,
|
||||
tp_group: ProcessGroup,
|
||||
moe_dp_group: ProcessGroup,
|
||||
ep_group: ProcessGroup,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
assert tp_group is not None
|
||||
assert moe_dp_group is not None
|
||||
assert ep_group is not None
|
||||
|
@ -70,6 +79,7 @@ class EPDeepseekMoE(nn.Module):
|
|||
self.ep_rank = dist.get_rank(ep_group)
|
||||
self.num_experts = self.config.n_routed_experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
|
@ -86,13 +96,32 @@ class EPDeepseekMoE(nn.Module):
|
|||
self.tp_group = tp_group
|
||||
if self.tp_group.size() > 1:
|
||||
for expert in held_experts:
|
||||
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
|
||||
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
|
||||
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
|
||||
expert.gate_proj = Linear1D_Col.from_native_module(
|
||||
expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
expert.up_proj = Linear1D_Col.from_native_module(
|
||||
expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
expert.down_proj = Linear1D_Row.from_native_module(
|
||||
expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
for p in self.experts.parameters():
|
||||
set_moe_tensor_ep_group(p, ep_group)
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
self.shared_experts.gate_proj = Linear1D_Col.from_native_module(
|
||||
self.shared_experts.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
self.shared_experts.up_proj = Linear1D_Col.from_native_module(
|
||||
self.shared_experts.up_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
self.shared_experts.down_proj = Linear1D_Row.from_native_module(
|
||||
self.shared_experts.down_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module,
|
||||
|
@ -106,7 +135,8 @@ class EPDeepseekMoE(nn.Module):
|
|||
if module.__class__.__name__ == "DeepseekMLP":
|
||||
return module
|
||||
module.__class__ = EPDeepseekMoE
|
||||
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
|
||||
fp8_communication = kwargs.get("fp8_communication", False)
|
||||
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
|
||||
return module
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -130,18 +160,32 @@ class EPDeepseekMoE(nn.Module):
|
|||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
|
||||
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
dist.all_to_all_single(
|
||||
output_split_sizes,
|
||||
input_split_sizes,
|
||||
group=self.ep_group,
|
||||
)
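# --- illustrative sketch (not part of the patch): what the split-size exchange yields ---
# With ep_size=2 and 2 experts per rank, dist.all_to_all_single on the per-expert token
# counts gives each rank the counts destined for *its* experts from every rank, matching
# the comment above. Pure-Python simulation of the collective's semantics:
ep_size_ex, experts_per_rank = 2, 2
counts = [[3, 1, 4, 2],   # rank 0: tokens routed to experts 0..3  -> [n0, n1, n2, n3]
          [5, 0, 2, 7]]   # rank 1                                  -> [m0, m1, m2, m3]
received = [
    [counts[src][dst * experts_per_rank + e] for src in range(ep_size_ex) for e in range(experts_per_rank)]
    for dst in range(ep_size_ex)
]
assert received[0] == [3, 1, 5, 0]   # rank 0 ends up with [n0, n1, m0, m1]
assert received[1] == [4, 2, 2, 7]   # rank 1 ends up with [n2, n3, m2, m3]
# --- end sketch ---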
|
||||
|
||||
with torch.no_grad():
|
||||
activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
|
||||
for i in range(1, self.ep_size):
|
||||
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
|
||||
activate_experts = (activate_experts > 0).float()
|
||||
|
||||
if self.fp8_communication:
|
||||
all_reduce_fp8(activate_experts, group=self.moe_dp_group)
|
||||
else:
|
||||
dist.all_reduce(activate_experts, group=self.moe_dp_group)
|
||||
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
output_states, _ = all_to_all_uneven(
|
||||
dispatch_states,
|
||||
input_split_list,
|
||||
output_split_list,
|
||||
self.ep_group,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
|
||||
|
||||
if output_states.size(0) > 0:
|
||||
|
@ -167,7 +211,9 @@ class EPDeepseekMoE(nn.Module):
|
|||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
dispatch_states, _ = all_to_all_uneven(
|
||||
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
recover_token_idx = torch.empty_like(flat_topk_token_idx)
|
||||
recover_token_idx[flat_topk_token_idx] = torch.arange(
|
||||
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
|
||||
|
@ -183,6 +229,79 @@ class EPDeepseekMoE(nn.Module):
|
|||
return output_hidden_states
|
||||
|
||||
|
||||
class DeepseekMoEGate_Col(ParallelModule):
|
||||
def parallel_linear(self, hidden_states):
|
||||
assert (
|
||||
hidden_states.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
hidden_states.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
|
||||
output = linear_with_async_comm(
|
||||
hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(
|
||||
output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
return output
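# --- illustrative sketch (not part of the patch): column-parallel gate linear + gather ---
# The gate weight is sharded row-wise (output features split across TP ranks); each rank
# computes its slice of the logits and the gather along dim=-1 restores the full
# [tokens, n_routed_experts] score matrix. Single-process equivalence check:
import torch

tp, tokens_ex, hidden_ex, n_experts_ex = 4, 6, 32, 16
x = torch.randn(tokens_ex, hidden_ex)
w = torch.randn(n_experts_ex, hidden_ex)          # full gate weight
shards = w.chunk(tp, dim=0)                       # each rank holds n_experts // tp rows
partial = [x @ s.t() for s in shards]             # per-rank partial logits
gathered = torch.cat(partial, dim=-1)             # gather_forward_split_backward(dim=-1)
assert torch.allclose(gathered, x @ w.t(), atol=1e-5)
# --- end sketch ---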
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = self.parallel_linear(hidden_states)
|
||||
if self.scoring_func == "softmax":
|
||||
scores = logits.softmax(dim=-1)
|
||||
else:
|
||||
raise NotImplementedError(f"unsupported scoring function for MoE gating: {self.scoring_func}")
|
||||
|
||||
### select top-k experts
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
### expert-level computation auxiliary loss
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
aux_topk = self.top_k
|
||||
# always compute aux loss based on the naive greedy topk method
|
||||
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
||||
if self.seq_aux:
|
||||
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
||||
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
||||
ce.scatter_add_(
|
||||
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
|
||||
).div_(seq_len * aux_topk / self.n_routed_experts)
|
||||
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
|
||||
else:
|
||||
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
|
||||
ce = mask_ce.float().mean(0)
|
||||
Pi = scores_for_aux.mean(0)
|
||||
fi = ce * self.n_routed_experts
|
||||
aux_loss = (Pi * fi).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = None
|
||||
|
||||
return topk_idx, topk_weight, aux_loss
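# --- illustrative sketch (not part of the patch): the expert-balance auxiliary loss ---
# For the non-seq_aux branch above: ce is the fraction of routed tokens each expert
# received, Pi the mean gate probability per expert, and fi = ce * n_experts. With
# perfectly uniform routing and uniform scores the loss reduces to alpha. Standalone check:
import torch
import torch.nn.functional as F

n_experts_ex, tokens_ex, top_k_ex, alpha_ex = 4, 8, 2, 0.01
scores_ex = torch.full((tokens_ex, n_experts_ex), 1.0 / n_experts_ex)      # uniform gate probabilities
routed = torch.arange(tokens_ex * top_k_ex) % n_experts_ex                  # perfectly balanced routing
ce_ex = F.one_hot(routed, num_classes=n_experts_ex).float().mean(0)         # fraction per expert
Pi_ex = scores_ex.mean(0)
fi_ex = ce_ex * n_experts_ex
aux_loss_ex = (Pi_ex * fi_ex).sum() * alpha_ex
assert torch.isclose(aux_loss_ex, torch.tensor(alpha_ex))
# --- end sketch ---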
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module, process_group: ProcessGroup, config, gather_output, fp8_communication
|
||||
) -> "DeepseekMoEGate_Col":
|
||||
LazyInitContext.materialize(module)
|
||||
module.process_group = process_group
|
||||
module.fp8_communication = fp8_communication
|
||||
sharded_weight = shard_rowwise(module.weight.data, process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, module.weight)
|
||||
module.__class__ = DeepseekMoEGate_Col
|
||||
return module
|
||||
|
||||
|
||||
class DeepseekPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
|
@ -534,9 +653,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
|
|||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
|
@ -595,7 +714,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
|
|||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
) # (1, 4, 256)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
|
@ -669,6 +790,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
|
|||
# TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
|
||||
self._use_flash_attention_2 = shard_config.enable_flash_attention
|
||||
self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
|
@ -688,9 +810,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
|
|||
)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
|
@ -734,9 +860,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
|
|||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
|
|
@ -21,8 +21,9 @@ from transformers.models.gpt2.modeling_gpt2 import (
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.layer import ColoAttention, RingAttention
|
||||
from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
|
||||
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import dist_cross_entropy
|
||||
|
@ -39,10 +40,16 @@ def _get_attention_mask(
|
|||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
encoder_attention_mask: Optional[torch.FloatTensor],
|
||||
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
|
||||
batch_size, seq_len = hidden_states.shape[:2]
|
||||
# Received input is already split for non-first pipeline stages,
|
||||
# but attn mask isn't
|
||||
batch_size = hidden_states.size(0)
|
||||
seq_len = attention_mask.size(-1)
|
||||
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
if shard_config.enable_flash_attention:
|
||||
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
|
@ -62,6 +69,7 @@ def _get_attention_mask(
|
|||
encoder_attention_mask = {"attention_mask": None}
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# GPT2Attention mask.
|
||||
past_key_values_length = 0
|
||||
if past_key_values is not None and past_key_values[0] is not None:
|
||||
|
@ -69,6 +77,7 @@ def _get_attention_mask(
|
|||
if shard_config.enable_flash_attention:
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
(batch_size, 1, seq_len, seq_len + past_key_values_length),
|
||||
hidden_states.dtype,
|
||||
|
@ -123,6 +132,7 @@ class GPT2PipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_gather: Optional[bool] = True,
|
||||
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
|
@ -146,16 +156,15 @@ class GPT2PipelineForwards:
|
|||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
disable_pp = stage_manager is None
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
|
@ -176,7 +185,7 @@ class GPT2PipelineForwards:
|
|||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
@ -190,9 +199,7 @@ class GPT2PipelineForwards:
|
|||
hidden_states = hidden_states + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
attn_kwargs, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
|
@ -215,22 +222,43 @@ class GPT2PipelineForwards:
|
|||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
# Ring Attention's special zigzag batch processing
|
||||
if sp_mode == "ring_attn":
|
||||
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||
if not attention_mask.bool().all():
|
||||
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||
attention_mask, sp_group, hidden_states, position_ids
|
||||
)
|
||||
else:
|
||||
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
|
||||
# Other sp modes
|
||||
else:
|
||||
if sp_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif sp_mode == "ring_attn":
|
||||
# Later stages already received split hidden states
|
||||
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
|
||||
del attention_mask
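# --- illustrative sketch (not part of the patch): zigzag batch split for Ring Attention ---
# Assuming split_batch_zigzag uses the common 2*sp_size-chunk layout (rank i keeps chunks
# i and 2*sp_size-1-i) so that causal-attention work is balanced across ranks:
sp_size_ex = 4
chunks = list(range(2 * sp_size_ex))                            # sequence cut into 8 chunks
per_rank = [[chunks[i], chunks[2 * sp_size_ex - 1 - i]] for i in range(sp_size_ex)]
assert per_rank == [[0, 7], [1, 6], [2, 5], [3, 4]]             # early + late chunk pairing
# --- end sketch ---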
|
||||
|
||||
# Going through held blocks.
|
||||
if disable_pp:
|
||||
start_idx, end_idx = 0, len(self.h)
|
||||
else:
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
||||
for i in range(start_idx, end_idx):
|
||||
block = self.h[i]
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if torch.is_tensor(attn_kwargs):
|
||||
attn_kwargs = attn_kwargs.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
|
@ -241,7 +269,7 @@ class GPT2PipelineForwards:
|
|||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
attn_kwargs,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
|
@ -252,7 +280,7 @@ class GPT2PipelineForwards:
|
|||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=attn_kwargs,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
|
@ -269,25 +297,25 @@ class GPT2PipelineForwards:
|
|||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
|
||||
gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
if gather_output:
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
# gather_sp_output could've changed seq length.
|
||||
input_shape = (*input_shape[:-1], hidden_states.size(-2))
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
|
@ -364,16 +392,28 @@ class GPT2PipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_gather=False,
|
||||
)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
disable_pp = stage_manager is None
|
||||
if (not disable_pp) and (not stage_manager.is_last_stage()):
|
||||
return {"hidden_states": outputs["hidden_states"]}
|
||||
|
||||
hidden_states = outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
if shard_config.sequence_parallelism_mode == "ring_attn":
|
||||
# Split labels in a zigzag fashion too
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if not attention_mask.bool().all():
|
||||
# [B, max_seqlen // sp_size]
|
||||
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||
else:
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
|
@ -768,7 +808,7 @@ class GPT2PipelineForwards:
|
|||
)
|
||||
|
||||
|
||||
def get_gpt2_flash_attention_forward():
|
||||
def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
|
||||
def forward(
|
||||
|
@ -815,6 +855,21 @@ def get_gpt2_flash_attention_forward():
|
|||
if self.scale_attn_by_inverse_layer_idx:
|
||||
scale /= float(self.layer_idx + 1)
|
||||
dropout_p = self.attn_dropout.p if self.training else 0.0
|
||||
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if sp_mode == "ring_attn":
|
||||
attn_output = RingAttention.attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
sp_group,
|
||||
**attention_mask,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
inner_ring_size=shard_config.inner_ring_size,
|
||||
)
|
||||
else:
|
||||
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
|
||||
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
attn_output = self.c_proj(attn_output)
|
||||
|
@ -826,464 +881,6 @@ def get_gpt2_flash_attention_forward():
|
|||
return forward
|
||||
|
||||
|
||||
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicates we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
||||
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicates we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger = logging.get_logger(__name__)
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
)
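# --- illustrative sketch (not part of the patch): the sequence-dim split above ---
# split_forward_gather_backward keeps this rank's slice of the sequence in the forward pass
# and all-gathers gradients in the backward pass. Shape-only, single-process illustration
# with an assumed SP size of 4:
import torch

sp_size_ex, rank_ex = 4, 1
h_ex = torch.randn(2, 1024, 768)                       # [batch_size, seq_len, hidden_size]
local_ex = h_ex.chunk(sp_size_ex, dim=1)[rank_ex]      # this rank's shard of the sequence
assert local_ex.shape == (2, 1024 // sp_size_ex, 768)
# --- end sketch ---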
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
# When sequence parallelism is done, gather the output tensor in the forward pass and split it in the backward pass
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
def forward(
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
return forward
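`dist_cross_entropy` above is fed logits whose vocabulary dimension may be sharded by tensor parallelism. As a hedged illustration of the vocabulary-sharded case only (the name `vocab_parallel_cross_entropy`, the `vocab_start` offset, and the `tp_group` handle are assumptions, not the real signature), the loss can be computed without materializing the full logits by combining three all-reduces:

```python
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(
    logits_shard: torch.Tensor,  # (N, V_local): this rank's slice of the vocab dim
    labels: torch.Tensor,        # (N,): full label ids, -100 = ignore
    vocab_start: int,            # first vocab index owned by this rank
    tp_group: dist.ProcessGroup,
) -> torch.Tensor:
    vocab_end = vocab_start + logits_shard.size(-1)

    # Stabilize with the global (all-rank) max logit per token.
    local_max = logits_shard.max(dim=-1).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=tp_group)
    shifted = logits_shard - local_max.unsqueeze(-1)

    # Global log-sum-exp = log of the summed local exp-sums.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

    # Each target logit lives on exactly one rank; the others contribute zero.
    in_range = (labels >= vocab_start) & (labels < vocab_end)
    local_idx = (labels - vocab_start).clamp(0, logits_shard.size(-1) - 1)
    target = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target = torch.where(in_range, target, torch.zeros_like(target))
    dist.all_reduce(target, op=dist.ReduceOp.SUM, group=tp_group)

    token_loss = sum_exp.log() - target
    valid = labels != -100  # HuggingFace ignore-index convention, as in the docstring above
    return (token_loss * valid).sum() / valid.sum().clamp(min=1)
```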
|
||||
|
||||
|
||||
def get_jit_fused_gpt2_mlp_forward():
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
|
||||
|
||||
|
|
|
@ -185,6 +185,7 @@ class GPTJPipelineForwards:
|
|||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
# Going through held blocks.
|
||||
|
@ -236,6 +237,7 @@ class GPTJPipelineForwards:
|
|||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
|
@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
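These GPTJ hunks thread `fp8_communication` into `split_forward_gather_backward` / `gather_forward_split_backward`, the pair that shards the sequence dimension on the way in and reassembles it (or its gradient) on the way out; the flag presumably just switches the underlying collective to an FP8-compressed variant. A simplified sketch of the split-forward/gather-backward primitive, without the FP8 path and not the ColossalAI implementation:

```python
import torch
import torch.distributed as dist


class _SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, dim: int, group: dist.ProcessGroup, grad_scale: float = 1.0):
        ctx.dim, ctx.group, ctx.grad_scale = dim, group, grad_scale
        rank, world = dist.get_rank(group), dist.get_world_size(group)
        # Keep only this rank's slice of the sequence dimension.
        return x.chunk(world, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        world = dist.get_world_size(ctx.group)
        bufs = [torch.empty_like(grad_out) for _ in range(world)]
        # Reassemble the full-sequence gradient from every rank's shard.
        dist.all_gather(bufs, grad_out.contiguous(), group=ctx.group)
        grad = torch.cat(bufs, dim=ctx.dim) * ctx.grad_scale
        return grad, None, None, None


def split_forward_gather_backward_sketch(x, dim, group, grad_scale=1.0):
    return _SplitForwardGatherBackward.apply(x, dim, group, grad_scale)
```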
|
||||
|
|
|
@ -25,7 +25,6 @@ from transformers.models.llama.modeling_llama import (
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import AttnMaskType
|
||||
from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
|
||||
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
@ -58,10 +57,7 @@ class LlamaPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
# Split output only when computing cross entropy using llama_for_causal_lm_forward
|
||||
# or get_lm_forward_with_dist_cross_entropy
|
||||
# Default to True to avoid bug when calling classification forward from huggingface
|
||||
force_sp_output_gather: bool = True,
|
||||
force_sp_gather: bool = True, # Set to false only when computing cross entropy
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -78,8 +74,9 @@ class LlamaPipelineForwards:
|
|||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
disable_pp = stage_manager is None
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if disable_pp or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
|
@ -88,10 +85,10 @@ class LlamaPipelineForwards:
|
|||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
device = hidden_states.device
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
|
@ -101,8 +98,8 @@ class LlamaPipelineForwards:
|
|||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
|
||||
# For generating full positions ids, as the states will be gather along the seq dim in the attention layer later.
|
||||
# Generating full positions ids for modes that gather sequence before attn
|
||||
if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
|
||||
seq_length *= sp_size
|
||||
|
||||
past_seen_tokens = 0
|
||||
|
@ -117,7 +114,6 @@ class LlamaPipelineForwards:
|
|||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
|
@ -130,14 +126,13 @@ class LlamaPipelineForwards:
|
|||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
|
||||
|
||||
no_split_input = disable_pp or not stage_manager.is_first_stage()
|
||||
if no_split_input and sp_mode == "ring_attn":
|
||||
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
|
||||
elif shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attn_kwargs = ColoAttention.prepare_attn_kwargs(
|
||||
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
|
@ -146,15 +141,15 @@ class LlamaPipelineForwards:
|
|||
invert=(sp_mode != "ring_attn"),
|
||||
)
|
||||
else:
|
||||
attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
|
||||
# Support SP + PP
|
||||
# TODO: support padded causal cu_seqlens across stages
|
||||
if stage_manager.is_first_stage():
|
||||
# Support SP + PP. Later stages have already received the split input.
|
||||
split_input = disable_pp or stage_manager.is_first_stage()
|
||||
if split_input:
|
||||
# Ring Attention zigzag batch processing
|
||||
if sp_mode == "ring_attn":
|
||||
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
|
||||
if not attention_mask.bool().all():
|
||||
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||
attention_mask, sp_group, hidden_states, position_ids
|
||||
)
|
||||
|
@ -162,9 +157,13 @@ class LlamaPipelineForwards:
|
|||
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
|
||||
|
||||
elif is_share_sp_tp(sp_mode):
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
if use_cache:
|
||||
|
@ -177,8 +176,8 @@ class LlamaPipelineForwards:
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
|
@ -224,16 +223,16 @@ class LlamaPipelineForwards:
|
|||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
|
||||
if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if stage_manager.is_last_stage():
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
|
@ -251,7 +250,7 @@ class LlamaPipelineForwards:
|
|||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
# always return dict for imediate stage
|
||||
# always return dict for intermediate stage
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
|
@ -317,7 +316,7 @@ class LlamaPipelineForwards:
|
|||
# Split labels in a zigzag fashion too
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if attention_mask.bool().all():
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||
else:
|
||||
# [B, max_seqlen // sp_size]
|
||||
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
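`split_batch_zigzag` balances the causal-attention workload across the ring: the sequence is cut into `2 * sp_size` chunks and rank `r` keeps chunks `r` and `2 * sp_size - 1 - r`, pairing a cheap early segment with an expensive late one. A single-tensor sketch under the assumption that the sequence length divides evenly (an illustration, not the library helper itself):

```python
import torch


def zigzag_split(tensor: torch.Tensor, sp_rank: int, sp_size: int, seq_dim: int = 1) -> torch.Tensor:
    chunks = tensor.chunk(2 * sp_size, dim=seq_dim)
    # Pair the r-th chunk with the (2 * sp_size - 1 - r)-th chunk so every rank
    # sees both an "early" (cheap) and a "late" (expensive) causal segment.
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)


# Example: sp_size=2 on an 8-token sequence.
x = torch.arange(8).unsqueeze(0)                 # (batch=1, seq=8)
print(zigzag_split(x, sp_rank=0, sp_size=2))     # tensor([[0, 1, 6, 7]])
print(zigzag_split(x, sp_rank=1, sp_size=2))     # tensor([[2, 3, 4, 5]])
```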
|
||||
|
@ -339,16 +338,17 @@ class LlamaPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
force_sp_gather=False,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
disable_pp = stage_manager is None
|
||||
if disable_pp or stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -532,9 +532,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
@ -605,7 +605,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
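Under `sp_mode == "all_to_all"`, the projections enter attention holding the full sequence but only a slice of the heads, and the `all_to_all_comm(..., scatter_dim=1, gather_dim=2)` call above converts the output back into a sequence shard with all heads. A bare-bones sketch of such a layout exchange (`all_to_all_sketch` is a hypothetical stand-in for `all_to_all_comm` and ignores gradient handling and FP8 compression):

```python
import torch
import torch.distributed as dist


def all_to_all_sketch(x: torch.Tensor, group: dist.ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world = dist.get_world_size(group)
    # Cut the tensor along the dimension being scattered, exchange the pieces,
    # then stitch the received pieces back along the dimension being gathered.
    inputs = [t.contiguous() for t in x.chunk(world, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(world)]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)


# Before attention: scatter the hidden/head dim, gather the sequence dim.
# After attention (as in the hunk above): scatter_dim=1 (sequence), gather_dim=2 (heads).
```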
|
||||
|
||||
|
@ -621,257 +623,3 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
# Split output only when computing cross entropy using llama_for_causal_lm_forward
|
||||
# or get_lm_forward_with_dist_cross_entropy
|
||||
# Default to True to avoid bug when calling classification forward from huggingface
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_seen_tokens = 0
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
|
||||
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
invert=(sp_mode != "ring_attn"),
|
||||
)
|
||||
|
||||
else:
|
||||
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
# Ring Attention zigzag batch processing
|
||||
if sp_mode == "ring_attn":
|
||||
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
|
||||
inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||
attention_mask, sp_group, inputs_embeds, position_ids
|
||||
)
|
||||
else:
|
||||
inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
|
||||
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
|
||||
|
||||
elif is_share_sp_tp(sp_mode):
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attn_kwargs,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attn_kwargs,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
# Cases that don't support parallelizing cross entropy computation along sequence
|
||||
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
|
||||
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
return forward
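Throughout these forwards, `fp8_communication=shard_config.fp8_communication` is threaded into the sequence-parallel collectives. As a hedged illustration of what FP8-compressed communication generally involves (the actual scaling and dtype handling live in `colossalai.quantization.fp8` and may differ), the sketch below casts to `float8_e4m3fn` with a per-tensor scale, ships the raw bytes, and decompresses on receipt:

```python
import torch
import torch.distributed as dist


def all_gather_fp8_sketch(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Hypothetical helper, not colossalai.quantization.fp8; needs PyTorch >= 2.1 for float8 dtypes.
    fp8 = torch.float8_e4m3fn
    scale = (x.abs().max().clamp(min=1e-12) / torch.finfo(fp8).max).reshape(1)
    payload = (x / scale).to(fp8).view(torch.uint8)  # ship the fp8 payload as raw bytes

    world = dist.get_world_size(group)
    byte_bufs = [torch.empty_like(payload) for _ in range(world)]
    scale_bufs = [torch.empty_like(scale) for _ in range(world)]
    dist.all_gather(byte_bufs, payload, group=group)
    dist.all_gather(scale_bufs, scale, group=group)

    # Decompress each shard back to the working dtype and stitch them together.
    shards = [b.view(fp8).to(x.dtype) * s for b, s in zip(byte_bufs, scale_bufs)]
    return torch.cat(shards, dim=0)
```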
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
def forward(
|
||||
self: LlamaForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
|
||||
# Special processing: Split labels in a zigzag fashion too
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
if attention_mask.bool().all():
|
||||
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||
else:
|
||||
# [B, max_seq_len // sp_size]
|
||||
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
|
|
@ -274,10 +274,9 @@ class MistralForwards:
|
|||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -687,10 +686,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -31,12 +31,13 @@ from colossalai.moe._operation import (
|
|||
all_to_all_uneven,
|
||||
)
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import all_reduce_fp8
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
||||
|
@ -49,11 +50,17 @@ if is_flash_attn_2_available():
|
|||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||
|
||||
|
||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||
class EPMixtralSparseMoeBlock(ParallelModule):
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
|
||||
|
||||
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
|
||||
def setup_process_groups(
|
||||
self,
|
||||
tp_group: ProcessGroup,
|
||||
moe_dp_group: ProcessGroup,
|
||||
ep_group: ProcessGroup,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
assert tp_group is not None
|
||||
assert moe_dp_group is not None
|
||||
assert ep_group is not None
|
||||
|
@ -62,6 +69,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
|||
self.ep_size = dist.get_world_size(ep_group)
|
||||
self.ep_rank = dist.get_rank(ep_group)
|
||||
self.ep_group = ep_group
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if self.num_experts % self.ep_size != 0:
|
||||
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
||||
|
@ -80,9 +88,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
|||
self.tp_group = tp_group
|
||||
if self.tp_group.size() > 1:
|
||||
for expert in held_experts:
|
||||
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
|
||||
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
|
||||
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
|
||||
expert.w1 = Linear1D_Col.from_native_module(
|
||||
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
expert.w3 = Linear1D_Col.from_native_module(
|
||||
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
expert.w2 = Linear1D_Row.from_native_module(
|
||||
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
for p in self.experts.parameters():
|
||||
set_moe_tensor_ep_group(p, ep_group)
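Sharding each held expert with `Linear1D_Col` for `w1`/`w3` and `Linear1D_Row` for `w2` keeps the expert MLP exact while needing only one reduction per expert. A single-process numerical sketch of that column/row split (plain tensors stand in for the parallel linears; the loop plays the role of the tensor-parallel ranks):

```python
import torch
import torch.nn.functional as F

hidden, inter, tp = 8, 32, 4
x = torch.randn(2, hidden)
w1 = torch.randn(inter, hidden)  # gate proj, column-sharded in the real layer
w3 = torch.randn(inter, hidden)  # up proj, column-sharded
w2 = torch.randn(hidden, inter)  # down proj, row-sharded

full = (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T

partials = []
for r in range(tp):  # each iteration stands in for one TP rank
    rows = slice(r * inter // tp, (r + 1) * inter // tp)
    h_r = F.silu(x @ w1[rows].T) * (x @ w3[rows].T)  # column-parallel part
    partials.append(h_r @ w2[:, rows].T)             # row-parallel part
sharded = sum(partials)  # played by an all-reduce in the distributed setting

print(torch.allclose(full, sharded, atol=1e-4))  # True up to floating-point error
```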
|
||||
|
@ -99,7 +113,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
|||
# TODO: better init
|
||||
LazyInitContext.materialize(module)
|
||||
module.__class__ = EPMixtralSparseMoeBlock
|
||||
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
|
||||
fp8_communication = kwargs.get("fp8_communication", False)
|
||||
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
|
||||
return module
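The forward pass below routes each token to the rank that hosts its expert via `all_to_all_uneven`, after first exchanging per-rank counts with `dist.all_to_all_single`. A hedged sketch of that variable-size exchange (`dispatch_tokens_sketch` is an illustrative name; the real kernel also handles gradients and, when `fp8_communication` is set, FP8 compression of the payload):

```python
import torch
import torch.distributed as dist


def dispatch_tokens_sketch(tokens: torch.Tensor, send_counts: torch.Tensor, ep_group: dist.ProcessGroup):
    # tokens: (sum(send_counts), hidden), already sorted by destination rank
    # send_counts: (ep_size,) int64, number of tokens bound for each expert-parallel rank
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=ep_group)

    in_splits, out_splits = send_counts.tolist(), recv_counts.tolist()
    received = tokens.new_empty(sum(out_splits), tokens.size(-1))
    # Variable-size all-to-all: rank i receives exactly the tokens routed to its experts.
    dist.all_to_all_single(
        received, tokens.contiguous(),
        output_split_sizes=out_splits, input_split_sizes=in_splits, group=ep_group,
    )
    return received, out_splits
```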
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -120,6 +135,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
|||
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
|
||||
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
with torch.no_grad():
|
||||
|
@ -127,12 +143,22 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
|||
for i in range(1, self.ep_size):
|
||||
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
|
||||
activate_experts = (activate_experts > 0).float()
|
||||
|
||||
if self.fp8_communication:
|
||||
all_reduce_fp8(activate_experts, group=self.moe_dp_group)
|
||||
else:
|
||||
dist.all_reduce(activate_experts, group=self.moe_dp_group)
|
||||
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
output_states, _ = all_to_all_uneven(
|
||||
dispatch_states,
|
||||
input_split_list,
|
||||
output_split_list,
|
||||
self.ep_group,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
# compute expert output
|
||||
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
|
||||
if output_states.size(0) > 0:
|
||||
|
@ -162,7 +188,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
|||
output_states = torch.cat(output_states_list)
|
||||
|
||||
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
dispatch_states, _ = all_to_all_uneven(
|
||||
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
||||
recover_experts_idx[selected_experts_idx] = torch.arange(
|
||||
|
@ -566,9 +594,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
|
|||
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
@ -673,7 +701,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
|
|||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
) # (1, 4, 256)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
|
@ -780,9 +810,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
|
|||
)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
|
@ -831,9 +865,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
|
|||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
|
|
@ -330,12 +330,13 @@ class OPTPipelineForwards:
|
|||
)
|
||||
if stage_manager.is_last_stage():
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.model.decoder.dtype,
|
||||
)
|
||||
|
||||
|
@ -955,9 +956,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -32,14 +32,12 @@ except ImportError:
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_forward_split_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
from ..layer._operation import gather_sp_output
|
||||
from ..layer.utils import is_share_sp_tp
|
||||
|
||||
|
||||
class Qwen2PipelineForwards:
|
||||
|
@ -64,6 +62,7 @@ class Qwen2PipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -115,6 +114,14 @@ class Qwen2PipelineForwards:
|
|||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
# Support SP + PP
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
# For generating full position ids (the states will be gathered along the seq dim before the attention fwd).
|
||||
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
|
||||
seq_length *= sp_size
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
|
@ -151,7 +158,6 @@ class Qwen2PipelineForwards:
|
|||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
|
@ -160,7 +166,6 @@ class Qwen2PipelineForwards:
|
|||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
|
@ -169,19 +174,20 @@ class Qwen2PipelineForwards:
|
|||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
if stage_manager.is_first_stage():
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if is_share_sp_tp(sp_mode):
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
process_group=sp_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
process_group=sp_group,
|
||||
grad_scale=1 / sp_size,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
|
@ -239,21 +245,10 @@ class Qwen2PipelineForwards:
|
|||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
)
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
@ -347,15 +342,18 @@ class Qwen2PipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
if hidden_states.shape[1] == 2:
|
||||
pass
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -516,9 +514,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
value_states = self.v_proj(hidden_states)
|
||||
# sp: all-to-all communication when introducing sequence parallelism
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
@ -537,7 +535,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
|
@ -604,7 +601,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
|
@ -629,6 +628,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
force_sp_output_gather: bool = True,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -702,9 +702,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
|||
next_decoder_cache = None
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
|
@ -740,10 +744,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
|||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||
hidden_states = gather_sp_output(hidden_states, shard_config)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
@ -820,14 +823,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
force_sp_output_gather=False,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
|
||||
)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
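The `force_sp_output_gather=False` path above leaves `hidden_states` (and hence the logits) sharded along the sequence dimension, so the loss has to be combined across the sequence-parallel group instead of gathering logits first. A hedged sketch of that reduction (`sharded_sequence_loss` is an illustrative name, not the `dist_cross_entropy` implementation; causal label shifting is assumed to have happened before the split):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def sharded_sequence_loss(local_logits: torch.Tensor, local_labels: torch.Tensor,
                          sp_group: dist.ProcessGroup) -> torch.Tensor:
    # local_logits: (B, S_local, V); local_labels: (B, S_local) with -100 marking ignored tokens.
    loss_sum = F.cross_entropy(
        local_logits.flatten(0, 1), local_labels.flatten(), ignore_index=-100, reduction="sum"
    )
    n_tokens = (local_labels != -100).sum().to(loss_sum.dtype)
    stats = torch.stack([loss_sum, n_tokens])
    dist.all_reduce(stats, group=sp_group)  # global loss sum and global token count
    return stats[0] / stats[1].clamp(min=1)
```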
|
||||
|
|
|
@ -98,6 +98,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -106,6 +107,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -114,6 +116,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -123,7 +126,10 @@ class BertPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
|
@ -136,12 +142,16 @@ class BertPolicy(Policy):
|
|||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
|
@ -180,6 +190,13 @@ class BertPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=embedding_cls,
|
||||
kwargs=(
|
||||
{
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {}
|
||||
),
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
|
@ -249,6 +266,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
policy=base_policy,
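These policy hunks all do the same thing: extend the kwargs of an existing `SubModuleReplacementDescription` with `fp8_communication` so the swapped-in parallel layer knows whether to compress its collectives. As a framework-free sketch of what suffix-based submodule replacement boils down to (a simplified stand-in, not the ColossalAI shard/policy machinery):

```python
import torch.nn as nn
from typing import Callable


def replace_by_suffix(model: nn.Module, suffix: str, factory: Callable[[nn.Module], nn.Module]) -> int:
    """Swap every submodule whose qualified name ends with `suffix`."""
    replaced = 0
    for name, module in list(model.named_modules()):
        if name and name.endswith(suffix):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, factory(module))
            replaced += 1
    return replaced


# Example: replace every "output.dense" linear (here with a fresh Linear of the same shape).
toy = nn.ModuleDict({"output": nn.ModuleDict({"dense": nn.Linear(4, 4)})})
replace_by_suffix(toy, "output.dense", lambda m: nn.Linear(m.in_features, m.out_features))
```

The real policies additionally handle lazy initialization, weight sharding, and the per-layer kwargs (`seq_parallel_mode`, `overlap`, `fp8_communication`) shown in the hunks above.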
|
||||
|
|
|
@ -72,20 +72,30 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.FusedLinear1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.projection",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
|
||||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -114,14 +124,23 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
|
@ -130,6 +149,9 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
|
@ -138,14 +160,23 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.query",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.key",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.dropout",
|
||||
|
@ -154,6 +185,9 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dropout",
|
||||
|
@ -162,10 +196,16 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="intermediate_query.dense",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dropout",
|
||||
|
@ -185,26 +225,44 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -225,7 +283,14 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="model.decoder.embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
|
@ -241,6 +306,7 @@ class BlipPolicy(Policy):
|
|||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
|