* [zerobubble] Support ZeroBubble Pipeline (#6034)
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
* [feat] add dw test;
* [fix] fix weight not close;
* [update] update text;
* [feat] add test run_fwd_bwd automatic scheduling;
* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;
* [feat] add test for p & p grad;
* [feat] add comments for ZBV func;
* [fix] rm useless assign and comments;
* [fix] fix ci test; add pytest;
* [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass;
* [feat] add apply v_schedule graph; p & p.grad assert err exist;
* [fix] update
* [feat] fix ci; add assert;
* [feat] fix poc format
* [feat] fix func name & ci; add comments;
* [fix] fix poc test; add comments in poc;
* [feat] add optim backward_b_by_grad
* [feat] fix optimizer bwd b & w; support return accum loss & output
* [feat] add fwd_bwd_step, run_fwd_only;
* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;
* [fix] fix communication_map;
* [feat] update test; rm comments;
* [fix] rm zbv in hybridplugin
* [fix] fix optim bwd;
* [fix] fix optim bwd;
* [fix] rm output.data after send fwd;
* [fix] fix bwd step if condition; remove useless comments and format info;
* [fix] fix detach output & release output;
* [fix] rm require_grad for output;
* [fix] fix require_grad position, detach position, and input & output local buffer append position;
* [feat] add memory assertion;
* [fix] fix mem check;
* [fix] mem assertion;
* [fix] fix mem assertion;
* [fix] fix mem; use a new model shape; only assert mem less than or equal to theoretical;
* [fix] fix model zoo import;
* [fix] fix redundant detach & clone; add buffer assertion at the end;
* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;
* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
* [fix] add testcase with microbatch 4;
* [feat] moehybrid support zerobubble;
* [fix] fix zerobubble pp for shardformer type input;
* [feat] add more test;
* [fix] fix require_grad & deallocate call;
* [fix] update bwd b&w input; dict --> list[torch.Tensor]
* [fix] fix bwd w input;
* [fix] fix mem assert;
* [fix] fix input_tensors buffer append: input_obj (dict) --> Tuple (microbatch, input_obj), and all related bwd b calculation logic;
* [fix] use tree_flatten replace dict traverse;
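The tree_flatten switch above replaces hand-written dict traversal with a pytree-style flatten. The real code uses `torch.utils._pytree`; the stand-in below is a minimal illustrative sketch of the idea, not the actual implementation:

```python
# Sketch of the pytree idea: flatten any nested dict/list into a flat list of
# leaves plus a structure spec, so every buffer operation (detach, clone,
# release, ...) loops over one flat list instead of special-casing dicts.

def tree_flatten(obj):
    if isinstance(obj, dict):
        keys = sorted(obj)
        leaves, specs = [], []
        for k in keys:
            sub_leaves, sub_spec = tree_flatten(obj[k])
            leaves += sub_leaves
            specs.append(sub_spec)
        return leaves, ("dict", keys, specs)
    if isinstance(obj, (list, tuple)):
        leaves, specs = [], []
        for v in obj:
            sub_leaves, sub_spec = tree_flatten(v)
            leaves += sub_leaves
            specs.append(sub_spec)
        return leaves, ("list", specs)
    return [obj], ("leaf",)          # tensors (here: plain values) are leaves

batch = {"input_ids": [1, 2], "labels": 3}   # hypothetical stage input
leaves, spec = tree_flatten(batch)           # leaves is a flat list
```

With this, shardformer-style dict inputs and plain tensor inputs go through the same code path, which is what the later "support shardformer model type" fixes rely on.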
* [fix] rm comments;
* [fix] fix fwd branch; fwd pass both micro_batch & internal_inputs;
* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;
* [fix] fix detach clone release order;
* [fix] fix ci --> oom in 4096 hidden dim;
* [fix] fix dumb clone;
* [fix] fix detach_output_obj clone;
* [fix] fix stage_indices;
* [fix] fix traverse; traverse dict --> traverse tensor List;
* [fix] fix zerobubble; support shardformer model type;
* [fix] rm comments;
* [fix] fix test_pipeline_utils ci;
* [fix] remove duplicate arg; rm comments;
* [fix] remove chunk 0 stage 0 bwd b; no need to compute the microbatch's dx;
* [fix] rm print & comments;
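The "bwd b" / "bwd w" split these commits keep referring to is the core of zero-bubble scheduling: the backward pass for the input (B, needed immediately by the upstream pipeline stage) is decoupled from the backward pass for the weights (W, which can be deferred to fill pipeline bubbles). A minimal sketch with a hand-written linear layer, illustrative only and not ColossalAI's actual API:

```python
# Split backward for y = W @ x into two independent passes:
#   B: gradient w.r.t. the input  -> needed right away by the previous stage
#   W: gradient w.r.t. the weight -> can be scheduled later to fill bubbles

def forward(W, x):
    # y[i] = sum_j W[i][j] * x[j]
    return [sum(W[i][j] * x[j] for j in range(len(x))) for i in range(len(W))]

def backward_b(W, dy):
    # dx = W^T @ dy; touches no weight gradients
    return [sum(W[i][j] * dy[i] for i in range(len(dy))) for j in range(len(W[0]))]

def backward_w(x, dy):
    # dW = outer(dy, x); independent of backward_b, so it can run later
    return [[dy[i] * x[j] for j in range(len(x))] for i in range(len(dy))]

W = [[1.0, 2.0], [3.0, 4.0]]
x = [5.0, 6.0]
dy = [1.0, 1.0]           # pretend upstream gradient dL/dy
dx = backward_b(W, dy)    # sent to the previous stage immediately (the "B" pass)
dW = backward_w(x, dy)    # deferred weight-gradient work (the "W" pass)
```

This also explains the "remove chunk 0 stage 0 bwd b" fix above: the first stage has no upstream stage waiting for dx, so its B pass is unnecessary.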
* [plugin] hybrid support zero bubble pipeline (#6060)
* hybrid support zbv
* fix
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>
* [feat] zerobubble support moehybridplugin;
* [feat] update optimizer bwd;
* [fix] fix build ci;
* [zerobubble] rebase main (#6075)
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [fp8] add fp8 comm for low level zero
* [test] add zero fp8 test case
* [Feature] llama shardformer fp8 support (#5938)
* add llama shardformer fp8
* Llama Shardformer Parity
* fix typo
* fix all reduce
* fix pytest failure
* fix reduce op and move function to fp8.py
* fix typo
* [FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
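The speedup above comes from hoisting an instance-independent dict out of the per-object construction path. A minimal sketch of the pattern, where `_build_difference_dict` and its contents are hypothetical stand-ins for DimSpec's real conversion-cost table:

```python
# Build an expensive, instance-independent dict once as a lazily initialized
# class attribute instead of rebuilding (or deepcopying) it in every __init__.

class DimSpec:
    _difference_dict = None  # shared by all instances

    def __init__(self, shard_list):
        self.shard_list = shard_list
        if DimSpec._difference_dict is None:              # built exactly once
            DimSpec._difference_dict = self._build_difference_dict()

    @staticmethod
    def _build_difference_dict():
        # placeholder for the real pairwise sharding-conversion cost table
        return {("R", "S0"): 1, ("S0", "R"): 1, ("R", "R"): 0}

    def difference(self, pair):
        return self._difference_dict[pair]

a, b = DimSpec([0]), DimSpec([1])
# a and b share one dict object; construction does no redundant copying
```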
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
* [fp8] support all2all fp8 (#5953)
* support all2all fp8
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [fp8] add fp8 linear (#5967)
* [fp8] add fp8 linear
* [test] fix fp8 linear test condition
* [test] fix fp8 linear test condition
* [test] fix fp8 linear test condition
* [fp8] support fp8 amp for hybrid parallel plugin (#5975)
* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility
* fix (#5976)
* [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* implement communication hook for FSDP params all-gather
* added unit test for fp8 operators
* support fp8 communication in GeminiPlugin
* update training scripts to support fsdp and fp8 communication
* fixed some minor bugs observed in unit test
* add all_gather_into_tensor_flat_fp8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add skip the test if torch < 2.2.0
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add skip the test if torch < 2.2.0
* add skip the test if torch < 2.2.0
* add fp8_comm flag
* rebase latest fp8 operators
* rebase latest fp8 operators
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [test ci] Feature/fp8 comm (#5981)
* fix
* fix
* fix
* [fp8] support gemini plugin (#5978)
* [fp8] refactor hook
* [fp8] support gemini plugin
* [example] add fp8 option for llama benchmark
* [fp8] use torch compile (torch >= 2.3.0) (#5979)
* [fp8] use torch compile (torch >= 2.4.0)
* [fp8] set use_fast_accum in linear
* [chore] formal version check
* [chore] fix sig
* [fp8] MoE support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* fp8
* fix
* bert and bloom
* chatglm and command
* gpt2,gptj,bert, falcon,blip2
* mistral,opy,sam,t5,vit,whisper
* fix
* fix
* fix
* [fp8] refactor fp8 linear with compile (#5993)
* [fp8] refactor fp8 linear with compile
* [fp8] fix linear test
* [fp8] fix linear test
* [fp8] support asynchronous FP8 communication (#5997)
* fix
* fix
* fix
* support async all2all
* support async op for all gather
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)
* [fp8] linear perf enhancement
* [fp8] update reduce-scatter test (#6002)
* fix
* fix
* fix
* fix
* [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)
* [fp8] zero support fp8 linear. (#6006)
* fix
* fix
* fix
* zero fp8
* zero fp8
* Update requirements.txt
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
* [fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by renaming the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* Support overall loss, update KTO logging
* [Docs] clarify launch port
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix sync condition (#6000)
* [plugin] add cast inputs option for zero (#6003)
* [pre-commit.ci] pre-commit autoupdate (#5995)
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)
* [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* [misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update train_dpo.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update low_level_zero_plugin.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [CI] Remove triton version for compatibility bug; update req torch >=2.2 (#6018)
* remove triton version
* remove torch 2.2
* remove torch 2.1
* debug
* remove 2.1 build tests
* require torch >=2.2
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [plugin] hotfix zero plugin (#6036)
* [plugin] hotfix zero plugin
* [plugin] hotfix zero plugin
* [Colossal-LLaMA] Refactor latest APIs (#6030)
* refactor latest code
* update api
* add dummy dataset
* update Readme
* add setup
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update files
* add PP support
* update arguments
* update argument
* reorg folder
* update version
* remove IB info
* update utils
* update readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update save for zero
* update save
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add apex
* update
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add fused norm (#6038)
* [FP8] unsqueeze scale to make it compatible with torch.compile (#6040)
* [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020)
* fix bug in load_state_dict_into_model; format error msg
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update utils.py
to support checking missing_keys
* Update general_checkpoint_io.py
fix bug in missing_keys error message
* retrigger tests
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Remove deprecated install (#6042)
* remove deprecated install
* remove unused folder
* [fp8] optimize all-gather (#6043)
* [fp8] optimize all-gather
* [fp8] fix all gather fp8 ring
* [fp8] enable compile
* [fp8] fix all gather fp8 ring
* [fp8] fix linear hook (#6046)
* [fp8] disable all_to_all_fp8 in intranode (#6045)
* enhance all_to_all_fp8 with internode comm control
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* disable some fp8 ops due to performance issue
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [release] update version (#6041)
* [release] update version
* [devops] update comp test
* [devops] update comp test debug
* [devops] debug comp test
* [devops] debug comp test
* [devops] debug comp test
* [devops] debug comp test
* [devops] debug comp test
* [Feature] Split cross-entropy computation in SP (#5959)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* adapt chatglm, command-R, qwen
* debug
* add comments
* q1 index only once
* remove events to simplify stream sync
* simplify forward/backward logic
* 2d ring forward passed
* 2d ring backward passed
* fixes
* fix ring attn loss
* 2D ring backward + llama passed
* merge
* update logger
* fix typo
* rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* remove typos
* fixes
* support GPT
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins
* [example] add mixtral benchmark
* [moe] refine assertion and check
* [moe] fix mixtral & add more tests
* [moe] consider checking dp * sp group and moe_dp_group
* [mixtral] remove gate tp & add more tests
* [deepseek] fix tp & sp for deepseek
* [mixtral] minor fix
* [deepseek] add deepseek benchmark
* [fp8] hotfix backward hook (#6053)
* [fp8] hotfix backward hook
* [fp8] hotfix pipeline loss accumulation
* [doc] update sp doc (#6055)
* update sp doc
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix the sp
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the attn
* fix
* fix
* fix
* fix
* [zerobubble]Support ZeroBubble Pipeline (#6034)
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
* [feat] add dw test;
* [fix] fix weight not close;
* [update] update text;
* [feat] add test run_fwd_bwd automatic scheduling;
* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;
* [feat] add test for p & p grad;
* [feat] add comments for ZBV func;
* [fix] rm useless assign and comments;
* [fix] fix ci test; add pytest;
* [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass;
* [feat] add apply v_schedule graph; p & p.grad assert err exist;
* [fix] update
* [feat] fix ci; add assert;
* [feat] fix poc format
* [feat] fix func name & ci; add comments;
* [fix] fix poc test; add comments in poc;
* [feat] add optim backward_b_by_grad
* [feat] fix optimizer bwd b & w; support return accum loss & output
* [feat] add fwd_bwd_step, run_fwd_only;
* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;
* [fix] fix communication_map;
* [feat] update test; rm comments;
* [fix] rm zbv in hybridplugin
* [fix] fix optim bwd;
* [fix] fix optim bwd;
* [fix] rm output.data after send fwd;
* [fix] fix bwd step if condition; remove useless comments and format info;
* [fix] fix detach output & release output;
* [fix] rm require_grad for output;
* [fix] fix require_grad position and detach position and input&output local buffer append position;
* [feat] add memory assertion;
* [fix] fix mem check;
* [fix] mem assertion
* [fix] fix mem assertion
* [fix] fix mem; use a new model shape; only assert mem less than or equal to theoretical;
* [fix] fix model zoo import;
* [fix] fix redundant detach & clone; add buffer assertion at the end;
* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;
* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
* [fix] add testcase with microbatch 4;
* [fp8] fix missing fp8_comm flag in mixtral (#6057)
* fix
* fix
* fix
* [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)
* all_gather only internode, fix pytest
* fix cuda arch <89 compile pytest error
* fix pytest failure
* disable all_gather_into_tensor_flat_fp8
* fix fp8 format
* fix pytest
* fix conversations
* fix chunk tuple to list
* [doc] FP8 training and communication document (#6050)
* Add FP8 training and communication document
* add fp8 docstring for plugins
* fix typo
* fix typo
* fix
* fix
* [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)
* [ColossalEval] support for vllm (#6056)
* support vllm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* modify vllm and update readme
* run pre-commit
* remove duplicated lines and refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update param name
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refine code
* update readme
* refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [release] update version (#6062)
* [feat] moehybrid support zerobubble;
* [fix] fix zerobubble pp for shardformer type input;
* [fix] fix require_grad & deallocate call;
* [fix] fix mem assert;
* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs;
* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;
* [fix] fix zerobubble; support shardformer model type;
* [fix] fix test_pipeline_utils ci;
* [plugin] hybrid support zero bubble pipeline (#6060)
* hybrid support zbv
* fix
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
* [fix] fix mixtral policy;
* [fix] fix mixtral policy;
* [feat] support zbv in mixtral benchmark;
* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;
* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;
* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv
* [zero bubble] support zero (#6080)
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [fp8] add fp8 comm for low level zero
* [test] add zero fp8 test case
* [Feature] llama shardformer fp8 support (#5938)
* add llama shardformer fp8
* Llama Shardformer Parity
* fix typo
* fix all reduce
* fix pytest failure
* fix reduce op and move function to fp8.py
* fix typo
* [FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
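The speed-up above comes from building `DimSpec`'s `difference_dict` once and sharing it, instead of rebuilding it per instance. A minimal sketch of that caching pattern (class name from the commit; the dict contents here are placeholders, not the real cost table):

```python
class DimSpec:
    """Toy sketch: share one difference_dict across all DimSpec instances."""

    _difference_dict = None  # class-level cache, built on first access

    def __init__(self, shard_list):
        self.shard_list = shard_list

    @classmethod
    def _build_difference_dict(cls):
        # Placeholder contents; the real dict maps (source, target)
        # dim-spec pairs to conversion costs and is identical for every
        # DimSpec object, so one shared copy suffices.
        return {("S0", "R"): 1, ("R", "S0"): 1, ("R", "R"): 0}

    @property
    def difference_dict(self):
        if DimSpec._difference_dict is None:
            DimSpec._difference_dict = self._build_difference_dict()
        return DimSpec._difference_dict

a, b = DimSpec(["S0"]), DimSpec(["R"])
# Both instances see the exact same dict object: built once, never deep-copied.
assert a.difference_dict is b.difference_dict
```

Avoiding the per-object construction (and the `deepcopy` calls the first commit removes) is what yields the reported intra-op plan generation speed-up.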
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transition between non-moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by renaming the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
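The `cast_to_fp8`/`cast_from_fp8` names come from the commit above; the body below is a torch-free sketch of the per-tensor scaling idea those operators rely on (pick a scale so the largest magnitude maps onto the E4M3 maximum of 448), not the library implementation, which additionally rounds to the actual 8-bit format:

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in fp8 E4M3

def cast_to_fp8(values):
    """Scale so the largest |value| lands exactly at the fp8 max.

    Returns the scaled values plus the scale needed to undo the cast.
    """
    amax = max(abs(v) for v in values) or 1.0
    scale = FP8_E4M3_MAX / amax
    return [v * scale for v in values], scale

def cast_from_fp8(fp8_values, scale):
    """Undo the scaling to recover (approximately) the original values."""
    return [v / scale for v in fp8_values]

q, scale = cast_to_fp8([0.5, -2.0, 1.25])
restored = cast_from_fp8(q, scale)
```

In the real operators the quantized values travel over the wire at 8 bits (halving or quartering communication volume versus fp16/fp32) and each rank carries the scale alongside; the "fix scaling algorithm" commit that follows concerns exactly this scale computation.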
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
* [fp8] support all2all fp8 (#5953)
* support all2all fp8
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [fp8] add fp8 linear (#5967)
* [fp8] add fp8 linear
* [test] fix fp8 linear test condition
* [test] fix fp8 linear test condition
* [test] fix fp8 linear test condition
* [fp8] support fp8 amp for hybrid parallel plugin (#5975)
* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility
* fix (#5976)
* [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* implement communication hook for FSDP params all-gather
* added unit test for fp8 operators
* support fp8 communication in GeminiPlugin
* update training scripts to support fsdp and fp8 communication
* fixed some minor bugs observed in unit test
* add all_gather_into_tensor_flat_fp8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add skip the test if torch < 2.2.0
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add skip the test if torch < 2.2.0
* add skip the test if torch < 2.2.0
* add fp8_comm flag
* rebase latest fp8 operators
* rebase latest fp8 operators
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [test ci] Feature/fp8 comm (#5981)
* fix
* fix
* fix
* [fp8] support gemini plugin (#5978)
* [fp8] refactor hook
* [fp8] support gemini plugin
* [example] add fp8 option for llama benchmark
* [fp8] use torch compile (torch >= 2.3.0) (#5979)
* [fp8] use torch compile (torch >= 2.4.0)
* [fp8] set use_fast_accum in linear
* [chore] formal version check
* [chore] fix sig
* [fp8] MoE support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* fp8
* fix
* bert and bloom
* chatglm and command
* gpt2, gptj, bert, falcon, blip2
* mistral, opt, sam, t5, vit, whisper
* fix
* fix
* fix
* [fp8] refactor fp8 linear with compile (#5993)
* [fp8] refactor fp8 linear with compile
* [fp8] fix linear test
* [fp8] fix linear test
* [fp8] support asynchronous FP8 communication (#5997)
* fix
* fix
* fix
* support async all2all
* support async op for all gather
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)
* [fp8] linear perf enhancement
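#6004 above raises the `torch.compile` floor for the fp8 linear from 2.3.0 to 2.4.0, and an earlier commit mentions a "formal version check". A minimal sketch of such a gate (the helper name is ours, not the library's; it tolerates local-version suffixes like `+cu121`):

```python
def meets_min_torch(version: str, minimum=(2, 4, 0)) -> bool:
    """Compare a dotted version string against a minimum version tuple.

    Strips local-version suffixes ('+cu121') and trailing non-digit
    markers ('a0', 'rc1') from each component before comparing.
    """
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        digits = "".join(ch for ch in piece if ch.isdigit())
        parts.append(int(digits or 0))
    return tuple(parts) >= minimum

# Gate the compiled fp8 linear path on the runtime torch version.
use_compiled_fp8 = meets_min_torch("2.4.0+cu121")
```

Falling back to the eager fp8 path when the check fails keeps older torch builds working instead of crashing at import time.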
* [fp8] update reduce-scatter test (#6002)
* fix
* fix
* fix
* fix
* [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)
* [fp8] zero support fp8 linear. (#6006)
* fix
* fix
* fix
* zero fp8
* zero fp8
* Update requirements.txt
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
* [fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* Support overall loss, update KTO logging
* [Docs] clarify launch port
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix sync condition (#6000)
* [plugin] add cast inputs option for zero (#6003)
* [pre-commit.ci] pre-commit autoupdate (#5995)
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)
* [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
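The "rescale" steps in the ring-attention commits above merge per-block attention outputs using each block's `softmax_lse` (log-sum-exp of the attention scores). A standalone scalar sketch of that merge rule (function names are ours, not the library's):

```python
import math

def merge_attn_blocks(out1, lse1, out2, lse2):
    """Combine two partial attention outputs whose softmax was taken over
    disjoint key blocks, using each block's log-sum-exp (lse).

    The merged lse is log(exp(lse1) + exp(lse2)), computed stably; the
    outputs are then reweighted by each block's share of the total mass.
    """
    m = max(lse1, lse2)
    lse = m + math.log(math.exp(lse1 - m) + math.exp(lse2 - m))
    w1, w2 = math.exp(lse1 - lse), math.exp(lse2 - lse)
    return w1 * out1 + w2 * out2, lse
```

Because the key blocks are disjoint, this merge reproduces full-softmax attention exactly; that identity is what lets each ring step fold in one more key/value block (and, per the commit above, lets the kv communication overlap with the rescale) without changing the result.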
* [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* [misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update train_dpo.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update low_level_zero_plugin.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [CI] Remove triton version for compatibility bug; update req torch >=2.2 (#6018)
* remove triton version
* remove torch 2.2
* remove torch 2.1
* debug
* remove 2.1 build tests
* require torch >=2.2
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [plugin] hotfix zero plugin (#6036)
* [plugin] hotfix zero plugin
* [plugin] hotfix zero plugin
* [Colossal-LLaMA] Refactor latest APIs (#6030)
* refactor latest code
* update api
* add dummy dataset
* update Readme
* add setup
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update files
* add PP support
* update arguments
* update argument
* reorg folder
* update version
* remove IB info
* update utils
* update readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update save for zero
* update save
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add apex
* update
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add fused norm (#6038)
* [FP8] unsqueeze scale to make it compatible with torch.compile (#6040)
* [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020)
* fix bug in load_state_dict_into_model; format error msg
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update utils.py
to support checking missing_keys
* Update general_checkpoint_io.py
fix bug in missing_keys error message
* retrigger tests
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Remove deprecated install (#6042)
* remove deprecated install
* remove unused folder
* [fp8] optimize all-gather (#6043)
* [fp8] optimize all-gather
* [fp8] fix all gather fp8 ring
* [fp8] enable compile
* [fp8] fix all gather fp8 ring
* [fp8] fix linear hook (#6046)
* [fp8] disable all_to_all_fp8 in intranode (#6045)
* enhance all_to_all_fp8 with internode comm control
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* disable some fp8 ops due to performance issue
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [release] update version (#6041)
* [release] update version
* [devops] update comp test
* [devops] update comp test debug
* [devops] debug comp test
* [devops] debug comp test
* [devops] debug comp test
* [devops] debug comp test
* [devops] debug comp test
* [Feature] Split cross-entropy computation in SP (#5959)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
* adapt chatglm, command-R, qwen
* debug
* add comments
* q1 index only once
* remove events to simplify stream sync
* simplify forward/backward logic
* 2d ring forward passed
* 2d ring backward passed
* fixes
* fix ring attn loss
* 2D ring backward + llama passed
* merge
* update logger
* fix typo
* rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* remove typos
* fixes
* support GPT
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
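Splitting the cross-entropy across sequence-parallel ranks (#5959 above) works because token losses are independent: each rank sums the loss over its own sequence shard, then an all-reduce adds up the partial sums and token counts before dividing. A torch-free sketch with the collective replaced by a plain Python sum over per-rank results (helper name is ours):

```python
import math

def shard_ce(logits_shard, labels_shard):
    """Partial cross-entropy over one sequence shard.

    Returns (loss_sum, n_tokens) so shards can be combined exactly.
    """
    loss_sum = 0.0
    for logits, label in zip(logits_shard, labels_shard):
        lse = math.log(sum(math.exp(x) for x in logits))
        loss_sum += lse - logits[label]  # -log softmax(logits)[label]
    return loss_sum, len(labels_shard)

# Two ranks each hold half of the sequence; the "all-reduce" is a sum here.
rank_results = [shard_ce([[2.0, 0.5]], [0]), shard_ce([[0.1, 1.0]], [1])]
total_loss = sum(s for s, _ in rank_results)
total_tokens = sum(n for _, n in rank_results)
mean_loss = total_loss / total_tokens
```

Reducing the (sum, count) pair rather than per-rank means is what keeps the sharded loss bit-for-bit consistent with the unsharded one when shard sizes differ.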
* [hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins
* [example] add mixtral benchmark
* [moe] refine assertion and check
* [moe] fix mixtral & add more tests
* [moe] consider checking dp * sp group and moe_dp_group
* [mixtral] remove gate tp & add more tests
* [deepseek] fix tp & sp for deepseek
* [mixtral] minor fix
* [deepseek] add deepseek benchmark
* [fp8] hotfix backward hook (#6053)
* [fp8] hotfix backward hook
* [fp8] hotfix pipeline loss accumulation
* [doc] update sp doc (#6055)
* update sp doc
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix the sp
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the attn
* fix
* fix
* fix
* fix
* [zerobubble] Support ZeroBubble Pipeline (#6034)
* [fix] add testcase with microbatch 4;
* [fp8] fix missing fp8_comm flag in mixtral (#6057)
* fix
* fix
* fix
* [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)
* all_gather only internode, fix pytest
* fix cuda arch <89 compile pytest error
* fix pytest failure
* disable all_gather_into_tensor_flat_fp8
* fix fp8 format
* fix pytest
* fix conversations
* fix chunk tuple to list
* [doc] FP8 training and communication document (#6050)
* Add FP8 training and communication document
* add fp8 docstring for plugins
* fix typo
* fix typo
* fix
* fix
* [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)
* [ColossalEval] support for vllm (#6056)
* support vllm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* modify vllm and update readme
* run pre-commit
* remove dupilicated lines and refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update param name
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refine code
* update readme
* refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [release] update version (#6062)
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
* [update] update text;
* [feat] add test run_fwd_bwd automatic scheduling;
* [feat] fix poc format
* [fix] fix poc test; add comments in poc;
* [feat] add optim backward_b_by_grad
* [feat] fix optimizer bwd b & w; support return accum loss & output
* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;
* [feat] update test; rm comments;
* [fix] rm zbv in hybridplugin
* [fix] fix optim bwd;
* [fix] fix optim bwd;
* [fix] rm output.data after send fwd;
* [fix] fix bwd step if condition; remove useless comments and format info;
* [fix] fix mem check;
* [fix] fix mem assertation
* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;
* [fix] fix model zoo import;
* [feat] moehybrid support zerobubble;
* [fix] fix zerobubble pp for shardformer type input;
* [fix] fix require_grad & deallocate call;
* [fix] fix mem assert;
* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'
* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;
* [fix] fix zerobubble; support shardformer model type;
* [fix] fix test_pipeline_utils ci;
* [plugin] hybrid support zero bubble pipeline (#6060)
* hybrid support zbv
* fix
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* zbv support zero
* fix
* fix
* fix
---------
Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;
* [feat] Linear1D_COL/ROW support zbv WeightGradStore;
* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;
* [fix] fix test case; moe error in second iter
* [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;
* [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd;
* [fix] debug zbv llama test;
* [fix] rm use_zbv flag in Shardconfig; rm debug info;
* [fix] add & fix llama test
* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);
* [fix] fix fail case test_shard_llama
* [fix] fix test_shard_llama
* [fix] fix llama modeling policy;
* [fix] fix test_shard_llama ci;
* [fix] fix test zerobubble
* [fix] fix handle name; rm useless comments;
* [fix] fix send recv signature;
* [fix] fix comment in llama & benchmark
* [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore
* [fix] fix linear (no tp) ops func name;
* [feat] support zbv in mixtral benchmark; (#6083)
* [feat] support zbv in mixtral benchmark;
* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;
* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;
* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv
* [fix] fix fp8 args in HybridParallel
* [fix] fix hybridparall use_fp8 config
* [fix] fix use_fp8 flag
* [fix] fix model zoo init
* [feat] support no_tp Linear for sharderformer.llama
* [fix] fix zbv llama pp4
* [fix] fix send_tensor_metadata & send_grad_metadata;
* [feat] fix testcase;
* [feat] support mixtral policy with zbv tp_Linear & non_tp_Linear
* [feat] update mixtral policy & bert policy for zerobubble
* [fix] fix p2p error in zbv
* [fix] fix attn
* [fix] fix mixtral modeling & policy; update wait handles; doing benchmarking for llama hybrid;
* [fix] fix zbv wait_handle
* [fix] rm debug info; update llama policy; update wait handle
* [fix] fix test_lora
* [fix] fix test_lora in llama policy
* [fix] fix wait handle in run_fwd_bwd
* [fix] remove debug info;
* [fix] rm unused comments
* [fix] fix fp8 overlap code
* [fix] fix yml file & v_schedule comments
* [fix] rm fwd only meta cache comments;
---------
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Shardformer is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background.
🔨 Usage
Quick Start
Sample API usage is given below (if you enable flash attention, please install flash_attn; in addition, xformers' cutlass_op provides a supplementary optimization):
```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch colossalai
colossalai.launch_from_torch()

# create a huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# create a shard config; tp_group and stage_manager are created beforehand
shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    pipeline_stage_manager=stage_manager,
    enable_tensor_parallelism=True,
    enable_fused_normalization=True,
    enable_flash_attention=True,
    enable_jit_fused=True,
    enable_sequence_parallelism=True,
    enable_sequence_overlap=True,
)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.to('cuda')

# do everything like normal
...
```
The following describes ShardConfig's arguments:
tensor_parallel_process_group: The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group.
pipeline_stage_manager: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
enable_tensor_parallelism: Whether to use tensor parallelism. Defaults to True.
enable_fused_normalization: Whether to use fused layernorm. Defaults to False.
enable_flash_attention: Whether to switch on flash attention. Defaults to False.
enable_jit_fused: Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap: Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
enable_all_optimization: Whether to turn on all optimization tools including fused normalization, flash attention, JIT fused operators, sequence parallelism and sequence overlap. Defaults to False.
extra_kwargs: A dict to store extra kwargs for ShardFormer.
Write your own policy
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in API Design.
```python
from colossalai.shardformer import Policy

class MyPolicy(Policy):
    # implement your own policy
    ...

# init model and shard former
...

# use the customized policy to shard the model
my_policy = MyPolicy()
shard_former.optimize(model, my_policy)
```
🗺 Roadmap
We will follow this roadmap to develop Shardformer:
API Design
API Implementation
Unit Testing
Policy Implementation
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
💡 API Design
We will discuss the major components of ShardFormer below to help you better understand how things work.
This section serves as the design doc for Shardformer and the function signature might differ from the actual implementation.
Please refer to the code for more details.
Distributed Modules
ShardFormer replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new forward function to execute distributed computation.
Each distributed module implements its from_native_module static method to convert the PyTorch module to its corresponding distributed module.
````python
class ParallelModule(torch.nn.Module):

    @abstractmethod
    def from_native_module(module: torch.nn.Module,
                           process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> "ParallelModule":
        """
        Convert a native module to a parallelized module.

        Examples:

        ```python
        # replace module
        my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        ```
        """
````
Shard Config
ShardConfig is a simple data class to tell ShardFormer how sharding will be performed.
```python
@dataclass
class ShardConfig:
    tensor_parallel_process_group: ProcessGroup = None
    enable_fused_normalization: bool = False
    ...

    # Some possible future config fields
    tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d']  # support different tensor parallel modes
    use_flash_attention: bool                               # whether to use flash attention to speed up attention
    extra_kwargs: Dict[str, Any]                            # extra kwargs for the shardformer
```
Policy
The Policy class describes how to handle the model sharding.
It is merely a description; the actual sharding is performed by ModelSharder.
We abstract the policy into three stages:
Preprocessing: call Policy.preprocess to do some prior work before sharding, for example, resizing the embedding
Providing ModulePolicyDescription: call Policy.module_policy to get a set of ModulePolicyDescription objects that tell ModelSharder how each submodule's attributes, child parameters, and deeper submodules will be substituted.
Postprocessing: call Policy.postprocess to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
```python
@dataclass
class ModulePolicyDescription:
    r"""
    Describe how the attributes and parameters will be transformed in a policy.

    Args:
        attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
        param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive only one argument: module.
        sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
            object which specifies the module to be replaced and the target module used for replacement.
        method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
    """
    attribute_replacement: Dict[str, Any] = None
    param_replacement: List[Callable] = None
    sub_module_replacement: List[SubModuleReplacementDescription] = None
    method_replacement: Dict[str, Callable] = None


@dataclass
class SubModuleReplacementDescription:
    r"""
    Describe how a submodule will be replaced.

    Args:
        suffix (str): used to get the submodule object
        target_module (ParallelModule): specifies the module class used to replace the submodule
        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method
        ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
    """
    suffix: str
    target_module: ParallelModule
    kwargs: Dict[str, Any] = None
    ignore_if_not_exist: bool = False


class Policy(ABC):
    r"""
    The base class for all the policies. Each model should have its own policy class,
    e.g. BertPolicy for the BERT model or OPTPolicy for the OPT model.

    Shardformer provides many built-in sharding policies for the mainstream models. You can use the
    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
    If you want to define your own policy, you can inherit from this class and override the methods you want to modify.
    """

    def __init__(self):
        self.model = None

    def set_model(self, model: nn.Module) -> None:
        """
        Set model as an attribute of the Policy object so that we can access the model's attributes.
        """
        self.model = model

    def set_shard_config(self, shard_config: ShardConfig) -> None:
        r"""
        Set shard config as an attribute of the Policy object.

        Args:
            shard_config (:class:`ShardConfig`): the shard config to apply
        """
        self.shard_config = shard_config
        self.config_sanity_check()

    @abstractmethod
    def preprocess(self) -> nn.Module:
        """
        Perform some preprocessing on the model, such as resizing the embedding size
        """
        ...

    @abstractmethod
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """
        Return the dict of module policies: the key is the original layer class and the value is the
        argument for the modified layer
        """
        ...

    @abstractmethod
    def postprocess(self) -> nn.Module:
        """
        Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head
        """
        ...
```
Model Sharder
ModelSharder is the class in charge of sharding the model based on the given policy.
```python
class ModelSharder:

    def __init__(self,
                 model: torch.nn.Module,
                 shard_config: ShardConfig,
                 policy: ShardPolicy = None):
        # TODO: input is a class or an object
        ...

    def shard(self) -> None:
        """
        Shard the model with the help of pre-processing, replace_model_class, replace_module, and post-processing.
        """
        ...

    def replace_module(self) -> None:
        """
        Replace the layers according to the policy. Call Policy.module_policy() to get the module policy,
        then call _replace_module recursively.
        """
        ...
```
User-facing API
We only expose a limited number of APIs to the user to keep their user experience simple and clean.
```python
class ShardFormer:
    """
    Parallelize model based on the given config and policy.

    Example:

        org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        shard_config = ShardConfig()
        shard_former = ShardFormer(shard_config=shard_config)
        model, shared_params = shard_former.optimize(org_model)
    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a distributed coordinator
        2. Serve as a store for the shard config
        """
        self.shard_config = shard_config
        self.coordinator = DistCoordinator()

    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
        r"""
        Optimize the model based on the given policy.

        Args:
            model (`torch.nn.Module`): the original huggingface model
            policy (`Policy`): the custom policy for sharding

        Returns: the sharded model and the shared parameters
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
        shared_params = sharder.shard()
        return model, shared_params
```
⌨️ Development Notes
Add New Policy to Shardformer
This section serves as a guideline for writing new policies and registering them into Shardformer.
Step 1. Write your own model policy
You can create a new file in the colossalai/shardformer/policies folder, name the file after the model, and implement your policy in this file. You should not import any model zoo library at the top of the file, because we do not want to import the library when the policy is not used. Libraries such as transformers should be imported only in the function body when needed.
Please follow the following protocols when writing your policy:
You have to make a clear decision about what exactly you want to replace in the original PyTorch module:
Use ModulePolicyDescription.attribute_replacement to replace the module attributes
Use ModulePolicyDescription.param_replacement to replace the module parameters
Use ModulePolicyDescription.sub_module_replacement to replace the submodules completely. The target module should implement the from_native_module for the replacement.
Use ModulePolicyDescription.method_replacement to replace the module methods. These replacement methods should be put in the shardformer/modeling/<model-name>.py.
You can implement the ParallelModule for primitive modules in the shardformer/layer/<model-name>.py file. Primitive modules refer to modules which are not composed of other modules. For example, the torch.nn.Linear module is a primitive module, while modules such as the BertEncoder module in the transformers library are composite modules. Primitive modules do not contain nested nn.Module members. For composite modules, you should consider using ModulePolicyDescription to implement your replacement.
ParallelModule is meant to be used in two ways: ParallelModule.from_native_module to convert a native PyTorch module into a ParallelModule, and ParallelModule(...) to instantiate the module directly, just like a normal PyTorch module. ParallelModule should only be implemented for modules whose weights are sharded. If you want to make your module compatible with ModulePolicyDescription.sub_module_replacement and there is no weight sharding in your module, you can just implement the from_native_module method without inheriting from ParallelModule, as in colossalai/shardformer/layer/normalization.py.
Do not import any file in the colossalai/shardformer/policies and colossalai/shardformer/modeling folders at module level, to avoid unwanted import errors. For example, if a file in these folders accidentally imports the transformers library at the top of the file, then users will have to install transformers even if they do not use this file. Any file in the modeling folder should only be imported by the policy file. A policy implementation should only be imported dynamically via the autopolicy, or manually via the ShardFormer module.
Try to keep import statements for third-party libraries such as transformers within the function body instead of the header section of the file, because we do not want to import the library when the policy is not used.
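To make the protocols above concrete, here is a minimal, self-contained sketch of the shape a policy file takes. The classes below are simplified stand-ins for the real ones in `colossalai.shardformer.policies.base_policy`, and all names, fields, and values (`MyModelPolicy`, `ParallelLinearStub`, the attribute paths) are hypothetical placeholders, not the actual ColossalAI API.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

# Simplified stand-ins for the real description classes; the actual
# ones live in colossalai.shardformer.policies.base_policy.
@dataclass
class SubModuleReplacementDescription:
    suffix: str        # dotted name of the submodule to replace
    target_module: Any # class implementing from_native_module

@dataclass
class ModulePolicyDescription:
    attribute_replacement: Dict[str, Any] = field(default_factory=dict)
    sub_module_replacement: List[SubModuleReplacementDescription] = field(default_factory=list)

class ParallelLinearStub:
    """Placeholder for a ParallelModule such as a column-parallel linear layer."""
    @staticmethod
    def from_native_module(module, **kwargs):
        return module  # a real implementation would shard the weights here

class MyModelPolicy:
    def module_policy(self) -> Dict[str, ModulePolicyDescription]:
        # Import transformers *inside* the function body, never at the
        # top of the file, so the library is only required when the
        # policy is actually used, e.g.:
        # from transformers.models.bert.modeling_bert import BertLayer
        return {
            "BertLayer": ModulePolicyDescription(
                attribute_replacement={"attention.self.num_attention_heads": 4},
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.self.query",
                        target_module=ParallelLinearStub,
                    )
                ],
            )
        }

policy = MyModelPolicy().module_policy()
print(sorted(policy))  # ['BertLayer']
```

The key points the sketch illustrates are the deferred third-party import and the use of `attribute_replacement` / `sub_module_replacement` to describe what changes, rather than performing the change eagerly.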
Step 2. Register your policy to the autopolicy
Next, you need to register your policy in the colossalai/shardformer/policies/autopolicy.py file.
For example, if we register the policy for the BERT model, we just add a key-value pair to the _POLICY_LIST dictionary. The key is the qualname of the model object (you can get it via model.__class__.__qualname__). The value is a PolicyLocation object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may import libraries (such as transformers) which we do not want to load when the policy is unused.
How to support those models in huggingface model hub but not in the transformers library
There are two cases:
The modeling file is in the transformers library, but the model weights are not. E.g. the model structure of "01-ai/Yi-34B" is the same as LLaMA, but its weights are not in the transformers library. In this case, we support llama as usual, and Yi-34B is also covered by the llama policy; we do not need to add a new policy for Yi-34B.
The modeling file is not in the transformers library, such as "THUDM/chatglm2-6b".
Taking "THUDM/chatglm2-6b" as an example, we illustrate how to support such a model in Shardformer.
Unlike llama, which is in the transformers library, we cannot import the chatglm2 model directly. Thus, the key in the policy registry should be the string of the class name, rather than the class itself.
When using such models, AutoModel is supported as usual. The policy will be automatically loaded by the autopolicy.
Write Your Unit Testing
This section serves as the guideline for testing the shardformer module.
Step 1. Add your model to the model zoo in the test kits.
Add your model to the tests/kit/model_zoo file. This allows you to define test-related components for this model. You can take tests/kit/model_zoo/transformers/llama.py as an example for reference.
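The registration in the model zoo roughly follows the pattern below. This is a self-contained mimic, not the real registry: the actual one in tests/kit/model_zoo may use different keyword names and additional fields (e.g. an output transform), so treat every name here as an assumption and consult tests/kit/model_zoo/transformers/llama.py for the real shape.

```python
# Toy registry mimicking the model-zoo registration pattern. The real
# registry in tests/kit/model_zoo may differ; field names here are
# assumptions for illustration only.
class ModelZooRegistry(dict):
    def register(self, name, model_fn, data_gen_fn, loss_fn):
        """Store the test-related components for one model."""
        self[name] = dict(model_fn=model_fn, data_gen_fn=data_gen_fn, loss_fn=loss_fn)

model_zoo = ModelZooRegistry()

model_zoo.register(
    name="transformers_mymodel",                      # hypothetical entry name
    model_fn=lambda: "tiny-model",                    # builds a small instance of the model
    data_gen_fn=lambda: {"input_ids": [[0, 1, 2]]},   # generates a fake micro-batch
    loss_fn=lambda output: 0.0,                       # reduces model output to a scalar loss
)
```

The point of the pattern is that a test can fetch `model_fn`, `data_gen_fn`, and `loss_fn` by name and run any registered model through the same harness.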
Step 2. Write your unit testing for the model
Next, implement your unit test in the tests/test_shardformer folder. Please refer to other similar tests for style consistency.
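At its core, such a test runs the original model and the sharded model on the same batch and asserts that outputs (and gradients) stay close. The sketch below illustrates only that comparison idea with plain Python lists; the helper name `check_output_close` is hypothetical, and the real tests in tests/test_shardformer use proper tensor utilities.

```python
# Minimal sketch of the closeness check a shardformer unit test performs.
# The helper name and tolerance are illustrative assumptions.
def check_output_close(org_out, shard_out, atol=1e-5):
    """Return True if every pair of elements differs by at most atol."""
    return all(abs(a - b) <= atol for a, b in zip(org_out, shard_out))

org_output = [0.12, -0.40, 0.88]    # stand-in for org_model(batch)
shard_output = [0.12, -0.40, 0.88]  # stand-in for shard_model(batch)
assert check_output_close(org_output, shard_output)
```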
Step 3. Execute your test
When you run tests locally, you should run tests for both your newly-added test file and the whole shardformer module tests.
```shell
# test for your own test file
pytest tests/test_shardformer/test_model/<your-file>.py

# test for the whole shardformer module
pytest tests/test_shardformer
```
📊 Benchmarking
System Performance
We conducted benchmark tests to evaluate the performance improvement of Shardformer, comparing the training time between the original model and the sharded model.
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :---: | :-------: | :---------: |
| 256   | 11.2ms    | 17.2ms      |
| 512   | 9.8ms     | 19.5ms      |
| 1024  | 19.6ms    | 18.9ms      |
| 2048  | 46.6ms    | 30.8ms      |
| 4096  | 160.5ms   | 90.4ms      |
In the case of using 4 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :---: | :-------: | :---------: |
| 256   | 10.0ms    | 21.1ms      |
| 512   | 11.5ms    | 20.2ms      |
| 1024  | 22.1ms    | 20.6ms      |
| 2048  | 46.9ms    | 24.8ms      |
| 4096  | 160.4ms   | 68.0ms      |
As shown in the tables above, once the sequence length reaches roughly 1024 or more, Shardformer's parallel optimization for long sequences becomes evident.
Convergence
To validate that training the model with Shardformer does not impact convergence, we fine-tuned the BERT model using both Shardformer and non-Shardformer approaches. The Shardformer example uses Pipeline Parallelism and Data Parallelism (ZeRO-1) simultaneously. We then compared the accuracy, loss, and F1 score of the training results.