Commit Graph

964 Commits (b8e770c832276d212673fe3d7f41a6ce2ee40858)

Author SHA1 Message Date
FoolPlayer 9ee4ebea83 [shardformer] support whisper (#4212)
* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme
2023-08-15 23:25:14 +08:00
FoolPlayer dd2bf02679 [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code
2023-08-15 23:25:14 +08:00
Baizhou Zhang 0ceec8f9a9 [pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354)
* add naive optimizer for 3DPlugin/refactor gpt2 shardformer test

* merge tests of PP/DP/TP combinations into one test file

* fix bug when sync grad for dp in HybridPlugin

* update supported precisions for 3DPlugin/fix bug when shifting tp_degree

* improve the passing of lazy_init

* modify lazy_init/use sync_shared_params
2023-08-15 23:25:14 +08:00
Jianghai f13954cd58 [pipeline] refactor test pipeline and remove useless utils in pipeline (#4324)
* refactor tests

* refactor bloom model

* finish policy tests

* refactor tests

* fix test pure pipeline

* remove test pipeline and cutdown launch process

* refactor tests

* refactor bloom model

* finish policy tests

* refactor tests

* fix test pure pipeline

* remove test pipeline and cutdown launch process
2023-08-15 23:25:14 +08:00
LuGY d3c6cd66f3 [pipeline] add unit test for 1f1b (#4303)
* add unit test for 1f1b

* polish code

* polish code and update ut version

* fix
2023-08-15 23:25:14 +08:00
Baizhou Zhang da3cef27ad [pipeline] fix return_dict/fix pure_pipeline_test (#4331) 2023-08-15 23:25:14 +08:00
Hongxin Liu 411cf1d2db [hotfix] fix gemini and zero test (#4333)
* [hotfix] fix gemini and zero test

* [hotfix] fix lazy init test

* [hotfix] fix lazy init test
2023-08-15 23:25:14 +08:00
Hongxin Liu 261eab02fb [plugin] add 3d parallel plugin (#4295)
* [amp] add mixed precision optimizer

* [plugin] add 3d parallel plugin

* [booster] support pipeline

* [plugin] 3d parallel plugin support clip grad norm

* [shardformer] fix sharder and add plugin test

* [plugin] rename 3d parallel plugin

* [ci] support testmon core pkg change detection (#4305)

* [hotfix] debug testmon

* [hotfix] fix llama

* [hotfix] fix p2p bugs

* [hotfix] fix requirements
2023-08-15 23:25:14 +08:00
FoolPlayer b3f5d7a3ba [shardformer] support pipeline base vit model (#4284)
* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* support base vit pipeline

* support vit downstream model

* fix vit shard test

* modify hidden states return type

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
2023-08-15 23:25:14 +08:00
Baizhou Zhang 083d7da33d [pipeline] add pipeline support for all T5 models (#4310)
* complete policy for T5Model & T5ForConditionalGeneration

* modify function signature in forwards

* add forward for T5model

* add forward for T5ForConditionalGeneration

* fix a bug

* fix hidden_states transporting in decoder

* fix the passing of encoder_outputs
2023-08-15 23:25:14 +08:00
Jianghai d0807122e2 [pipeline] test pure pipeline process using llama (#4218)
* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be revert and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision

* add pure pipeline test

* fixed version

* fixed version

* pure pipeline
2023-08-15 23:25:14 +08:00
Baizhou Zhang 36e546b2cc [pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300)
* modify t5 policy & add test

* pipeline stage distribution for t5

* complete t5 base policy

* t5 stack: halfway

* modify gpt2 pipeline test

* complete pipeline forward for T5Stack/T5EncoderModel

* fix docstring

* move t5 util tests to test_pipeline
2023-08-15 23:25:14 +08:00
Jianghai d8408d185c [pipeline] OPT model pipeline (#4258)
* opt forward and test

* pause

* finish opt model pipeline

* finish opt pipeline

* opt forward and test

* pause

* finish opt model pipeline

* finish opt pipeline

* fix opt

* set transformers version

* refactor the test pipeline
2023-08-15 23:25:14 +08:00
Hongxin Liu d921ce8391 [shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
2023-08-15 23:25:14 +08:00
Baizhou Zhang 2a2eacfaf1 [pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)
* change for transformers loggers

* add forward for GPT2ForQuestionAnswering

* fix assert

* fix torchrec test
2023-08-15 23:25:14 +08:00
Jianghai d9be0472ef [bugs] hot fix some testing bugs for new models (#4268)
* hot fix

* hot fx tracer
2023-08-15 23:25:14 +08:00
Jianghai 34f0e34a4c [pipeline] finish bloom models pipeline and tests (#4223)
* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* finish bloom model

* test shard gpt2

* clear cache

* support all bloom models

* add bloom models policies

* finish bloom pipeline and tests

* add set pipeline

* finish bloom
2023-08-15 23:25:14 +08:00
Jianghai e7cc62d735 [pipeline] All bert models (#4233)
* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be revert and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision

* add pure pipeline test

* finish some bert models

* finish all bert models

* finish bert tests

* fix bugs

* fix bugs

* fix test pipeline

* fix data gen for qa

* update the set pipeline forward

* shared params

* fix bugs
2023-08-15 23:25:14 +08:00
Baizhou Zhang a14d352088 [pipeline] add pipeline forward for variants of gpt2 (#4238)
* add forward for GPTLMHeadModel

* add test for gpt_lm

* arranging get_held_layers method

* arrange forward replacement

* add forward for GPT2ForTokenClassification

* add forward for GPT2ForSequenceClassification

* fix test_shard_gpt2.py

* add GPT2DoubleHeadsmodel & fix bugs

* add id checking in get_shared_params
2023-08-15 23:25:14 +08:00
Baizhou Zhang 208ac8f2ba [pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224)
* * fix typehint & docstring in sharder.py

* * update pipeline forward for GPT2Model

* * add test for pipeline forward of GPT2Model

* * add cache cleaning in gpt2 test

* * change assert to raise command
2023-08-15 23:25:14 +08:00
Jianghai 37d22f6878 [pipeline] add bloom model pipeline (#4210)
* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* finish bloom model

* test shard gpt2

* clear cache
2023-08-15 23:25:14 +08:00
Jianghai 31bcf867ae [pipeline] Llama causal lm and llama for sequence classification pipeline (#4208)
* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be revert and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision
2023-08-15 23:25:14 +08:00
Jianghai 1622031058 [pipeline] Llama pipeline (#4205)
* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be revert and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt
2023-08-15 23:25:14 +08:00
Jianghai 1094e0f0d3 [pipeline] Bert pipeline for shardformer and its tests (#4197)
* add pipeline forward

* complete pipeline forward check

* fix bert forward without pipeline

* fix comments

* discard useless line

* add todo

* clean prints

* fix distribute layers
2023-08-15 23:25:14 +08:00
Hongxin Liu 890774b2fb [shardformer] support lazy init (#4202)
* [shardformer] support lazy init

* [shardformer] linear support lazy init

* [shardformer] embedding support lazy init

* [shardformer] norm support lazy init

* [shardformer] fused linear support lazy init

* [test] update shardformer test layer

* [test] shardformer with lazy init fit ddp

* [lazy] hotfix deepcopy of param

* [shardformer] fix bert policy and update test

* [shardformer] fix bloom policy and update test

* [shardformer] fix opt policy and update test

* [shardformer] fix t5 policy and update test

* [shardformer] fix gpt2 policy and update test

* [shardformer] fix llama policy and update test
2023-08-15 23:25:14 +08:00
Jianghai f3bcc292c8 [pipeline] move bert related pipeline components to shardformer (#4187)
* move bert related pipeline components to shardformer

* fix bugs

* revision

* fix bert model tests

* fix bert_lm_head model tests

* fix tests

* fix tests

* done checks

* skip bloom
2023-08-15 23:25:14 +08:00
Jianghai c5ea728016 [pipeline] add bert_for_pretraining bert_lmhead forward and policy (#4172)
* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt

* add bloom model and policy ,revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the imediate output to default dict

* change the default output of get_shared_params
2023-08-15 23:25:14 +08:00
ver217 5fc60a3a04 [test] add shard util tests 2023-08-15 23:25:14 +08:00
ver217 2d6cc07feb [test] update shardformer tests 2023-08-15 23:25:14 +08:00
Jianghai 90a65ea682 [pipeline] build bloom model and policy , revise the base class of policy (#4161)
* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt

* add bloom model and policy ,revise the base class of policy

* revise

* revision

* add bert_for_pretraining
2023-08-15 23:25:14 +08:00
Jianghai c552cefa93 [pipeline]add pipeline policy and bert forward (#4130)
* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt
2023-08-15 23:25:14 +08:00
Hongxin Liu 5c897ddb94 [pipeline] add stage manager (#4093)
* [pipeline] add stage manager

* [test] add pipeline stage manager test

* [pipeline] add docstring for stage manager
2023-08-15 23:25:14 +08:00
Jianghai e8e7e49243 [pipeline]add pipeline policy and bert forward (#4130)
* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt
2023-08-15 23:25:14 +08:00
Hongxin Liu f51ce1bc8e [pipeline] refactor 1f1b schedule (#4115)
* [api] update optimizer wrapper to fit pipeline

* [pipeline] add base schedule

* [pipeline] add 1f1b schedule

* [test] add pipeline schedule utils test

* [pipeline] fix import
2023-08-15 23:25:14 +08:00
Hongxin Liu 45fdc9b42c [pipeline] implement p2p communication (#4100)
* [pipeline] add p2p communication

* [test] add p2p communication test

* [test] add rerun decorator

* [test] rename to avoid conflict
2023-08-15 23:25:14 +08:00
Hongxin Liu 422544222f [pipeline] add stage manager (#4093)
* [pipeline] add stage manager

* [test] add pipeline stage manager test

* [pipeline] add docstring for stage manager
2023-08-15 23:25:14 +08:00
Hongxin Liu 5e1a9d48dd [cluster] add process group mesh (#4039)
* [cluster] add process group mesh

* [test] add process group mesh test

* force sync
2023-08-15 23:25:14 +08:00
LuGY d86ddd9b29
[hotfix] fix unsafe async comm in zero (#4404)
* improve stablility of zero

* fix wrong index

* add record stream
2023-08-11 15:09:24 +08:00
flybird1111 458ae331ad
[kernel] updated unittests for coloattention (#4389)
Updated coloattention tests of checking outputs and gradients
2023-08-09 14:24:45 +08:00
flybird1111 38b792aab2
[coloattention] fix import error (#4380)
fixed an import error
2023-08-04 16:28:41 +08:00
flybird1111 25c57b9fb4
[fix] coloattention support flash attention 2 (#4347)
Improved ColoAttention interface to support flash attention 2. Solved #4322
2023-08-04 13:46:22 +08:00
Hongxin Liu 16bf4c0221
[test] remove useless tests (#4359)
* [test] remove legacy zero test

* [test] remove lazy distribute test

* [test] remove outdated checkpoint io
2023-08-01 18:52:14 +08:00
LuGY 1a49a5ea00 [zero] support shard optimizer state dict of zero (#4194)
* support shard optimizer of zero

* polish code

* support sync grad manually
2023-07-31 22:13:29 +08:00
LuGY dd7cc58299 [zero] add state dict for low level zero (#4179)
* add state dict for zero

* fix unit test

* polish
2023-07-31 22:13:29 +08:00
LuGY c668801d36 [zero] allow passing process group to zero12 (#4153)
* allow passing process group to zero12

* union tp-zero and normal-zero

* polish code
2023-07-31 22:13:29 +08:00
LuGY 79cf1b5f33 [zero]support no_sync method for zero1 plugin (#4138)
* support no sync for zero1 plugin

* polish

* polish
2023-07-31 22:13:29 +08:00
LuGY c6ab96983a [zero] refactor low level zero for shard evenly (#4030)
* refactor low level zero

* fix zero2 and support cpu offload

* avg gradient and modify unit test

* refactor grad store, support layer drop

* refactor bucket store, support grad accumulation

* fix and update unit test of zero and ddp

* compatible with tp, ga and unit test

* fix memory leak and polish

* add zero layer drop unittest

* polish code

* fix import err in unit test

* support diffenert comm dtype, modify docstring style

* polish code

* test padding and fix

* fix unit test of low level zero

* fix pad recording in bucket store

* support some models

* polish
2023-07-31 22:13:29 +08:00
Baizhou Zhang c6f6005990
[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)
* sharded optimizer checkpoint for gemini plugin

* modify test to reduce testing time

* update doc

* fix bug when keep_gatherd is true under GeminiPlugin
2023-07-21 14:39:01 +08:00
Hongxin Liu fc5cef2c79
[lazy] support init on cuda (#4269)
* [lazy] support init on cuda

* [test] update lazy init test

* [test] fix transformer version
2023-07-19 16:43:01 +08:00
Cuiqing Li 4b977541a8
[Kernels] added triton-implemented of self attention for colossal-ai (#4241)
* added softmax kernel

* added qkv_kernel

* added ops

* adding tests

* upload tets

* fix tests

* debugging

* debugging tests

* debugging

* added

* fixed errors

* added softmax kernel

* clean codes

* added tests

* update tests

* update tests

* added attention

* add

* fixed pytest checking

* add cuda check

* fix cuda version

* fix typo
2023-07-18 23:53:38 +08:00
Baizhou Zhang 58913441a1
Next commit [checkpointio] Unsharded Optimizer Checkpoint for Gemini Plugin (#4141)
* [checkpointio] unsharded optimizer checkpoint for Gemini plugin

* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
2023-07-07 16:33:06 +08:00
github-actions[bot] c77b3b19be
[format] applied code formatting on changed files in pull request 4152 (#4157)
Co-authored-by: github-actions <github-actions@github.com>
2023-07-04 16:07:47 +08:00
Frank Lee 1fb0d95df0 [shardformer] made tensor parallelism configurable (#4144)
* [shardformer] made tensor parallelism configurable

* polish code
2023-07-04 16:05:01 +08:00
Frank Lee 74257cb446 [shardformer] refactored some doc and api (#4137)
* [shardformer] refactored some doc and api

* polish code
2023-07-04 16:05:01 +08:00
Frank Lee ae035d305d [shardformer] added embedding gradient check (#4124) 2023-07-04 16:05:01 +08:00
Frank Lee 6a88bae4ec [shardformer] integrate with data parallelism (#4103) 2023-07-04 16:05:01 +08:00
Frank Lee f3b6aaa6b7 [shardformer] supported fused normalization (#4112) 2023-07-04 16:05:01 +08:00
Frank Lee b1c2901530 [shardformer] supported bloom model (#4098) 2023-07-04 16:05:01 +08:00
Kun Lin 8af29ee47a [shardformer] support vision transformer (#4096)
* first v of vit shardformer

* keep vit

* update

* vit shard add vitattention vitlayer

* update num head shard para

* finish test for vit

* add new_model_class & postprocess

* add vit readme

* delete old files & fix the conflict

* fix sth
2023-07-04 16:05:01 +08:00
jiangmingyan ac80937138 [shardformer] shardformer support opt models (#4091)
* [shardformer] shardformer support opt models

* [shardformer] shardformer support opt models, fix

* [shardformer] shardformer support opt models, fix

* [shardformer] shardformer support opt models, fix
2023-07-04 16:05:01 +08:00
Frank Lee d33a44e8c3 [shardformer] refactored layernorm (#4086) 2023-07-04 16:05:01 +08:00
Frank Lee c4b1b65931 [test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change

* polish code
2023-07-04 16:05:01 +08:00
FoolPlayer 92f6791095 [shardformer] Add layernorm (#4072)
* add layernorm to bert

* add layernorm test

* add layernorm test with load state dict

* add use_mixedfusedLN in shard config

* refactor policy to support fused_layernorm
2023-07-04 16:05:01 +08:00
Frank Lee 70c58cfd4f [shardformer] supported fused qkv checkpoint (#4073) 2023-07-04 16:05:01 +08:00
FoolPlayer 0803a61412 [shardformer] add linearconv1d test (#4067)
* add linearconv1d test

* add linearconv1d test
2023-07-04 16:05:01 +08:00
Frank Lee 8eb09a4c69 [shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading

* polish code
2023-07-04 16:05:01 +08:00
FoolPlayer 7740c55c55 support kit use for bert/gpt test (#4055)
* support kit use for bert test

* support kit test for gpt2
2023-07-04 16:05:01 +08:00
Frank Lee f22ddacef0 [shardformer] refactored the shardformer layer structure (#4053) 2023-07-04 16:05:01 +08:00
Frank Lee 58df720570 [shardformer] adapted T5 and LLaMa test to use kit (#4049)
* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
2023-07-04 16:05:01 +08:00
FoolPlayer 4021b9a8a2 [shardformer] add gpt2 test and layer class refactor (#4041)
* add gpt2 test and layer class refactor

* add dropout in gpt2 policy
2023-07-04 16:05:01 +08:00
Frank Lee d857f3dbba [shardformer] supported T5 and its variants (#4045) 2023-07-04 16:05:01 +08:00
Frank Lee c1d5453e9f [shardformer] adapted llama to the new API (#4036) 2023-07-04 16:05:01 +08:00
FoolPlayer 74d176c8d8 [shardformer] fix bert and gpt downstream with new api (#4024)
* fix bert downstream with new api

* remove comment line
2023-07-04 16:05:01 +08:00
FoolPlayer 507c0ad368 add vocabembedding layer 2023-07-04 16:05:01 +08:00
Frank Lee 3893fa1a8d [shardformer] refactored embedding and dropout to parallel module (#4013)
* [shardformer] refactored embedding and dropout to parallel module

* polish code
2023-07-04 16:05:01 +08:00
FoolPlayer dfca9678fa integrate with dist layer (#4011) 2023-07-04 16:05:01 +08:00
Frank Lee 015af592f8 [shardformer] integrated linear 1D with dtensor (#3996)
* [shardformer] integrated linear 1D with dtensor

* polish code
2023-07-04 16:05:01 +08:00
Frank Lee 611971248c [device] support init device mesh from process group (#3990) 2023-07-04 16:05:01 +08:00
FoolPlayer f7774ec0f3 [Shardformer] Downstream bert (#3979)
* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage

* add downstream model of bert

* remove unused code
2023-07-04 16:05:01 +08:00
wukong1992 c1c672d0f0 [shardformer] shardformer support t5 model (#3994)
test t5
2023-07-04 16:05:01 +08:00
wukong1992 6b30dfb7ce [shardformer] support llama model using shardformer (#3969)
adjust layer attr
2023-07-04 16:05:01 +08:00
FoolPlayer a73130482d [shardformer] Unit test (#3928)
* fix bug in slicer, add slicer unit test

* add dropout test

* use pid as dropout seed

* updata dropout test with local pattern

* ad todo
2023-07-04 16:05:01 +08:00
FoolPlayer f1cb5ac6bf [shardformer] Align bert value (#3907)
* add bert align test, fix dist loss bug

* forward and backward align

* add ignore index

* add shardformer CI

* add gather_output optional for user in shardconfig

* update readme with optional gather_ouput

* add dist crossentropy loss test, remove unused files

* remove unused file

* remove unused file

* rename the file

* polish code
2023-07-04 16:05:01 +08:00
Baizhou Zhang 0bb0b481b4 [gemini] fix argument naming during chunk configuration searching 2023-06-25 13:34:15 +08:00
github-actions[bot] a52f62082d
[format] applied code formatting on changed files in pull request 4021 (#4022)
Co-authored-by: github-actions <github-actions@github.com>
2023-06-19 11:23:24 +08:00
Frank Lee a5883aa790
[test] fixed codefactor format report (#4026) 2023-06-16 18:23:02 +08:00
Baizhou Zhang 822c3d4d66
[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002) 2023-06-16 14:14:05 +08:00
Wenhao Chen 725af3eeeb
[booster] make optimizer argument optional for boost (#3993)
* feat: make optimizer optional in Booster.boost

* test: skip unet test if diffusers version > 0.10.2
2023-06-15 17:38:42 +08:00
Baizhou Zhang c9cff7e7fa
[checkpointio] General Checkpointing of Sharded Optimizers (#3984) 2023-06-15 15:21:26 +08:00
digger yu e61ffc77c6
fix typo tests/ (#3936) 2023-06-09 09:49:41 +08:00
Frank Lee ddcf58cacf
Revert "[sync] sync feature/shardformer with develop" 2023-06-09 09:41:27 +08:00
Frank Lee eb39154d40
[dtensor] updated api and doc (#3845) 2023-06-08 10:18:17 +08:00
Hongxin Liu ae02d4e4f7
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [text] update low level zero test

* [text] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
2023-06-05 15:58:31 +08:00
Hongxin Liu dbb32692d2
[lazy] refactor lazy init (#3891)
* [lazy] remove old lazy init

* [lazy] refactor lazy init folder structure

* [lazy] fix lazy tensor deepcopy

* [test] update lazy init test
2023-06-05 14:20:47 +08:00
wukong1992 6b305a99d6
[booster] torch fsdp fix ckpt (#3788) 2023-05-23 16:58:45 +08:00
Frank Lee 615e2e5fc1
[test] fixed lazy init test import error (#3799) 2023-05-23 11:57:15 +08:00
Hongxin Liu 3c07a2846e
[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)
* [test] refactor torch ddp checkpoint test

* [plugin] update low level zero optim checkpoint

* [plugin] update gemini optim checkpoint
2023-05-19 19:42:31 +08:00
Hongxin Liu 5452df63c5
[plugin] torch ddp plugin supports sharded model checkpoint (#3775)
* [plugin] torch ddp plugin add save sharded model

* [test] fix torch ddp ckpt io test

* [test] fix torch ddp ckpt io test

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] remove debug info
2023-05-18 20:05:59 +08:00
wukong1992 6050f37776
[booster] removed models that don't support fsdp (#3744)
Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>
2023-05-15 19:35:21 +08:00
Hongxin Liu afb239bbf8
[devops] update torch version of CI (#3725)
* [test] fix flop tensor test

* [test] fix autochunk test

* [test] fix lazyinit test

* [devops] update torch version of CI

* [devops] enable testmon

* [devops] fix ci

* [devops] fix ci

* [test] fix checkpoint io test

* [test] fix cluster test

* [test] fix timm test

* [devops] fix ci

* [devops] fix ci

* [devops] fix ci

* [devops] fix ci

* [devops] force sync to test ci

* [test] skip fsdp test
2023-05-15 17:20:56 +08:00
wukong1992 b37797ed3d
[booster] support torch fsdp plugin in booster (#3697)
Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>
2023-05-15 12:14:38 +08:00
digger-yu 1f73609adb
[CI] fix typo with tests/ etc. (#3727)
* fix spelling error with examples/comminity/

* fix spelling error with tests/

* fix some spelling error with tests/ colossalai/ etc.

* fix spelling error with tests/ etc. date:2023.5.10
2023-05-11 16:30:58 +08:00
digger-yu b7141c36dd
[CI] fix some spelling errors (#3707)
* fix spelling error with examples/comminity/

* fix spelling error with tests/

* fix some spelling error with tests/ colossalai/ etc.
2023-05-10 17:12:03 +08:00
jiangmingyan 20068ba188
[booster] add tests for ddp and low level zero's checkpointio (#3715)
* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update tests for booster

* [booster] update booster tutorials#3717, fix recursive check
2023-05-10 12:17:02 +08:00
Hongxin Liu 6552cbf8e1
[booster] fix no_sync method (#3709)
* [booster] fix no_sync method

* [booster] add test for ddp no_sync

* [booster] fix merge

* [booster] update unit test

* [booster] update unit test

* [booster] update unit test
2023-05-09 11:10:02 +08:00
Hongxin Liu 3bf09efe74
[booster] update prepare dataloader method for plugin (#3706)
* [booster] add prepare dataloader method for plug

* [booster] update examples and docstr
2023-05-08 15:44:03 +08:00
Hongxin Liu d0915f54f4
[booster] refactor all dp fashion plugins (#3684)
* [booster] add dp plugin base

* [booster] inherit dp plugin base

* [booster] refactor unit tests
2023-05-05 19:36:10 +08:00
digger-yu b49020c1b1
[CI] Update test_sharded_optim_with_sync_bn.py (#3688)
fix spelling error in line23
change "cudnn_determinstic"=True to "cudnn_deterministic=True"
2023-05-05 18:57:27 +08:00
jiangmingyan 307894f74d
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
2023-05-05 14:37:21 +08:00
Hongxin Liu 50793b35f4
[gemini] accelerate inference (#3641)
* [gemini] support don't scatter after inference

* [chat] update colossalai strategy

* [chat] fix opt benchmark

* [chat] update opt benchmark

* [gemini] optimize inference

* [test] add gemini inference test

* [chat] fix unit test ci

* [chat] fix ci

* [chat] fix ci

* [chat] skip checkpoint test
2023-04-26 16:32:40 +08:00
Hongxin Liu 4b3240cb59
[booster] add low level zero plugin (#3594)
* [booster] add low level zero plugin

* [booster] fix gemini plugin test

* [booster] fix precision

* [booster] add low level zero plugin test

* [test] fix booster plugin test oom

* [test] fix booster plugin test oom

* [test] fix googlenet and inception output trans

* [test] fix diffuser clip vision model

* [test] fix torchaudio_wav2vec2_base

* [test] fix low level zero plugin test
2023-04-26 14:37:25 +08:00
Hongxin Liu f313babd11
[gemini] support save state dict in shards (#3581)
* [gemini] support state dict shard

* [gemini] add test state dict shard

* [gemini] polish docstr

* [gemini] fix merge

* [gemini] polish code
2023-04-17 17:11:09 +08:00
Hongxin Liu 152239bbfa
[gemini] gemini supports lazy init (#3379)
* [gemini] fix nvme optimizer init

* [gemini] gemini supports lazy init

* [gemini] add init example

* [gemini] add fool model

* [zero] update gemini ddp

* [zero] update init example

* add chunk method

* add chunk method

* [lazyinit] fix lazy tensor tolist

* [gemini] fix buffer materialization

* [misc] remove useless file

* [booster] update gemini plugin

* [test] update gemini plugin test

* [test] fix gemini plugin test

* [gemini] fix import

* [gemini] fix import

* [lazyinit] use new metatensor

* [lazyinit] use new metatensor

* [lazyinit] fix __set__ method
2023-04-12 16:03:25 +08:00
jiangmingyan 52a933e175
[checkpoint] support huggingface style sharded checkpoint (#3461)
* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
2023-04-06 16:23:39 +08:00
Frank Lee 80eba05b0a
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
2023-04-06 14:51:35 +08:00
ver217 933048ad3e
[test] reorganize zero/gemini tests (#3445) 2023-04-06 09:38:25 +08:00
YuliangLiu0306 ffcdbf0f65
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
2023-04-04 17:40:45 +08:00
Frank Lee 1beb85cc25
[checkpoint] refactored the API and added safetensors support (#3427)
* [checkpoint] refactored the API and added safetensors support

* polish code
2023-04-04 15:23:01 +08:00
ver217 26b7aac0be
[zero] reorganize zero/gemini folder structure (#3424)
* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
2023-04-04 13:48:16 +08:00
Frank Lee 638a07a7f9
[test] fixed gemini plugin test (#3411)
* [test] fixed gemini plugin test

* polish code

* polish code
2023-04-03 17:12:22 +08:00
ver217 5f2e34e6c9
[booster] implement Gemini plugin (#3352)
* [booster] add gemini plugin

* [booster] update docstr

* [booster] gemini plugin add coloparam convertor

* [booster] fix coloparam convertor

* [booster] fix gemini plugin device

* [booster] add gemini plugin test

* [booster] gemini plugin ignore sync bn

* [booster] skip some model

* [booster] skip some model

* [booster] modify test world size

* [booster] modify test world size

* [booster] skip test
2023-03-31 16:06:13 +08:00
HELSON 1a1d68b053
[moe] add checkpoint for moe models (#3354)
* [moe] add checkpoint for moe models

* [hotfix] fix bugs in unit test
2023-03-31 09:20:33 +08:00
YuliangLiu0306 fee2af8610
[autoparallel] adapt autoparallel with new analyzer (#3261)
* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
2023-03-30 17:47:24 +08:00
Frank Lee 73d3e4d309
[booster] implemented the torch ddd + resnet example (#3232)
* [booster] implemented the torch ddd + resnet example

* polish code
2023-03-27 10:24:14 +08:00
YuliangLiu0306 4d5d8f98a4
[API] implement device mesh manager (#3221)
* [API] implement  device mesh manager

* polish
2023-03-24 13:39:12 +08:00
YuliangLiu0306 045afa3ea2
[hotfix] skip torchaudio tracing test (#3211)
* [hotfix] skip torchaudio tracing test

* fix lazy init test issue
2023-03-24 12:15:33 +08:00
Frank Lee cd142fbefa
[api] implemented the checkpoint io module (#3205)
* [api] implemented the checkpoint io module

* polish code

* polish code
2023-03-23 10:53:17 +08:00
ver217 f8289d4221
[lazyinit] combine lazy tensor with dtensor (#3204)
* [lazyinit] lazy tensor add distribute

* [lazyinit] refactor distribute

* [lazyinit] add test dist lazy init

* [lazyinit] add verbose info for dist lazy init

* [lazyinit] fix rnn flatten weight op

* [lazyinit] polish test

* [lazyinit] polish test

* [lazyinit] fix lazy tensor data setter

* [lazyinit] polish test

* [lazyinit] fix clean

* [lazyinit] make materialize inplace

* [lazyinit] refactor materialize

* [lazyinit] refactor test distribute

* [lazyinit] fix requires_grad

* [lazyinit] fix tolist after materialization

* [lazyinit] refactor distribute module

* [lazyinit] polish docstr

* [lazyinit] polish lazy init context

* [lazyinit] temporarily skip test

* [lazyinit] polish test

* [lazyinit] add docstr
2023-03-23 10:53:06 +08:00
YuliangLiu0306 019a847432
[Analyzer] fix analyzer tests (#3197) 2023-03-22 13:38:11 +08:00
YuliangLiu0306 f57d34958b
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
2023-03-22 10:40:33 +08:00
Frank Lee e7f3bed2d3
[booster] added the plugin base and torch ddp plugin (#3180)
* [booster] added the plugin base and torch ddp plugin

* polish code

* polish code

* polish code
2023-03-21 17:39:30 +08:00
Zihao 18dbe76cae
[auto-parallel] add auto-offload feature (#3154)
* add auto-offload feature

* polish code

* fix syn offload runtime pass bug

* add offload example

* fix offload testing bug

* fix example testing bug
2023-03-21 14:17:41 +08:00
zbian 7bc0afc901 updated flash attention usage 2023-03-20 17:57:04 +08:00
Frank Lee 085e7f4eff
[test] fixed torchrec registration in model zoo (#3177)
* [test] fixed torchrec registration in model zoo

* polish code

* polish code

* polish code
2023-03-20 16:19:06 +08:00
Frank Lee a9b8402d93
[booster] added the accelerator implementation (#3159) 2023-03-20 13:59:24 +08:00
Frank Lee 1ad3a636b1
[test] fixed torchrec model test (#3167)
* [test] fixed torchrec model test

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
2023-03-20 11:40:25 +08:00
ver217 6ae8ed0407
[lazyinit] add correctness verification (#3147)
* [lazyinit] fix shared module

* [tests] add lazy init test utils

* [tests] add torchvision for lazy init

* [lazyinit] fix pre op fn

* [lazyinit] handle legacy constructor

* [tests] refactor lazy init test models

* [tests] refactor lazy init test utils

* [lazyinit] fix ops don't support meta

* [tests] lazy init test timm models

* [lazyinit] fix set data

* [lazyinit] handle apex layers

* [tests] lazy init test transformers models

* [tests] lazy init test torchaudio models

* [lazyinit] fix import path

* [tests] lazy init test torchrec models

* [tests] update torch version in CI

* [tests] revert torch version in CI

* [tests] skip lazy init test
2023-03-17 13:49:04 +08:00
Frank Lee ed19290560
[booster] implemented mixed precision class (#3151)
* [booster] implemented mixed precision class

* polish code
2023-03-17 11:00:15 +08:00
YuliangLiu0306 ecd643f1e4
[test] add torchrec models to test model zoo (#3139) 2023-03-15 05:46:04 +00:00
ver217 14a115000b
[tests] model zoo add torchaudio models (#3138)
* [tests] model zoo add torchaudio models

* [tests] refactor torchaudio wavernn

* [tests] refactor fx torchaudio tests
2023-03-15 11:51:16 +08:00
Frank Lee 6d48eb0560
[test] added transformers models to test model zoo (#3135) 2023-03-15 11:26:10 +08:00
Frank Lee a674c63348
[test] added torchvision models to test model zoo (#3132)
* [test] added torchvision models to test model zoo

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
2023-03-15 10:42:07 +08:00
HELSON 1216d1e7bd
[tests] diffuser models in model zoo (#3136)
* [tests] diffuser models in model zoo

* remove useless code

* [tests] add diffusers to requirement-test
2023-03-14 17:20:28 +08:00
YuliangLiu0306 2eca4cd376
[DTensor] refactor dtensor with new components (#3089)
* [DTensor] refactor dtensor with new components

* polish
2023-03-14 16:25:47 +08:00
Frank Lee 86ac782d7c
[test] added timm models to test model zoo (#3129)
* [test] added timm models to test model zoo

* polish code

* polish code

* polish code

* polish code

* polish code
2023-03-14 14:29:18 +08:00
Xuanlei Zhao 30dd13c450
[autochunk] support complete benchmark (#3121)
* refact memory code

* dont log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* update typo

* update test

* add unet test

* add bench

* update bench

* update bench

* init

* support vit

* move to cpu

* add cpu benchmark
2023-03-13 17:42:37 +08:00
Super Daniel fff98f06ed
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device.

* [siu] add experimental submodules to main branch.

* [siu]

* [siu]

* [analyzer] init.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [test] add test.

* Update symbolic_trace.py

* mark skip tests.

* try except.

* try except.

* try except.

* s

* init

* init

* fix

* skip

* skip

---------

Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com>
Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
2023-03-10 13:21:05 +08:00
Xuanlei Zhao 10c61de2f7
[autochunk] support vit (#3084)
support vit for autochunk
* support some new ops for vit
* fix some bugs
* add test for vit
2023-03-10 10:23:26 +08:00
YuliangLiu0306 8e4e8601b7
[DTensor] implement layout converter (#3055)
* [DTensor] refactor LayoutConverter for DTensor

* polish code

* polish docstring
2023-03-10 09:53:52 +08:00
Xuanlei Zhao 2ca9728cbb
[autochunk] refactor chunk memory estimation (#2762)
* refact memory code

* dont log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* update typo

* update test
2023-03-08 16:22:30 +08:00