* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* add naive optimizer for 3DPlugin/refactor gpt2 shardformer test
* merge tests of PP/DP/TP combinations into one test file
* fix bug when sync grad for dp in HybridPlugin
* update supported precisions for 3DPlugin/fix bug when shifting tp_degree
* improve the passing of lazy_init
* modify lazy_init/use sync_shared_params
* refactor tests
* refactor bloom model
* finish policy tests
* refactor tests
* fix test pure pipeline
* remove test pipeline and cutdown launch process
* refactor tests
* refactor bloom model
* finish policy tests
* refactor tests
* fix test pure pipeline
* remove test pipeline and cutdown launch process
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* support base vit pipeline
* support vit downstream model
* fix vit shard test
* modify hidden states return type
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
* complete policy for T5Model & T5ForConditionalGeneration
* modify function signature in forwards
* add forward for T5model
* add forward for T5ForConditionalGeneration
* fix a bug
* fix hidden_states transporting in decoder
* fix the passing of encoder_outputs
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be revert and copied to feature/bloom
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* add pure pipeline test
* fixed version
* fixed version
* pure pipeline
* opt forward and test
* pause
* finish opt model pipeline
* finish opt pipeline
* opt forward and test
* pause
* finish opt model pipeline
* finish opt pipeline
* fix opt
* set transformers version
* refactor the test pipeline
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be revert and copied to feature/bloom
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* add pure pipeline test
* finish some bert models
* finish all bert models
* finish bert tests
* fix bugs
* fix bugs
* fix test pipeline
* fix data gen for qa
* update the set pipeline forward
* shared params
* fix bugs
* * fix typehint & docstring in sharder.py
* * update pipeline forward for GPT2Model
* * add test for pipeline forward of GPT2Model
* * add cache cleaning in gpt2 test
* * change assert to raise command
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* finish bloom model
* test shard gpt2
* clear cache
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be revert and copied to feature/bloom
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be revert and copied to feature/bloom
* revert the bloom changes
* cancel unneeded inputs
* gpt
* [shardformer] support lazy init
* [shardformer] linear support lazy init
* [shardformer] embedding support lazy init
* [shardformer] norm support lazy init
* [shardformer] fused linear support lazy init
* [test] update shardformer test layer
* [test] shardformer with lazy init fit ddp
* [lazy] hotfix deepcopy of param
* [shardformer] fix bert policy and update test
* [shardformer] fix bloom policy and update test
* [shardformer] fix opt policy and update test
* [shardformer] fix t5 policy and update test
* [shardformer] fix gpt2 policy and update test
* [shardformer] fix llama policy and update test
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update formatting
* update the code
* fix bugs
* fix name confilt
* add bloom model and policy ,revise the base class of policy
* revise
* revision
* add bert_for_pretraining
* add bert_for_pretraining forward and policy
* fix typos
* cancel warning
* change the imediate output to default dict
* change the default output of get_shared_params
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update formatting
* update the code
* fix bugs
* fix name confilt
* add bloom model and policy ,revise the base class of policy
* revise
* revision
* add bert_for_pretraining
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update formatting
* update the code
* fix bugs
* fix name confilt
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update formatting
* update the code
* fix bugs
* fix name confilt
* refactor low level zero
* fix zero2 and support cpu offload
* avg gradient and modify unit test
* refactor grad store, support layer drop
* refactor bucket store, support grad accumulation
* fix and update unit test of zero and ddp
* compatible with tp, ga and unit test
* fix memory leak and polish
* add zero layer drop unittest
* polish code
* fix import err in unit test
* support diffenert comm dtype, modify docstring style
* polish code
* test padding and fix
* fix unit test of low level zero
* fix pad recording in bucket store
* support some models
* polish
* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gatherd is true under GeminiPlugin
* first v of vit shardformer
* keep vit
* update
* vit shard add vitattention vitlayer
* update num head shard para
* finish test for vit
* add new_model_class & postprocess
* add vit readme
* delete old files & fix the conflict
* fix sth
* add layernorm to bert
* add layernorm test
* add layernorm test with load state dict
* add use_mixedfusedLN in shard config
* refactor policy to support fused_layernorm
* [bf16] add bf16 support for fused adam (#3844)
* [bf16] fused adam kernel support bf16
* [test] update fused adam kernel test
* [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
* [bf16] add mixed precision mixin
* [bf16] low level zero optim support bf16
* [text] update low level zero test
* [text] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
* [bf16] gemini support bf16
* [test] update gemini bf16 test
* [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
* [zero] init context support bf16
* [zero] legacy zero support bf16
* [test] add zero bf16 test
* [doc] add bf16 related docstring for legacy zero
* [plugin] torch ddp plugin add save sharded model
* [test] fix torch ddp ckpt io test
* [test] fix torch ddp ckpt io test
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] remove debug info
* [test] fix flop tensor test
* [test] fix autochunk test
* [test] fix lazyinit test
* [devops] update torch version of CI
* [devops] enable testmon
* [devops] fix ci
* [devops] fix ci
* [test] fix checkpoint io test
* [test] fix cluster test
* [test] fix timm test
* [devops] fix ci
* [devops] fix ci
* [devops] fix ci
* [devops] fix ci
* [devops] force sync to test ci
* [test] skip fsdp test