[Feature] LoRA rebased to main branch (#5622)
* [Inference]ADD Bench Chatglm2 script (#4963)
* add bench chatglm
* fix bug and make utils
---------
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [Pipeline inference] Combine kvcache with pipeline inference (#4938)
* merge kvcache with pipeline inference and refactor the code structure
* support ppsize > 2
* refactor pipeline code
* do pre-commit
* modify benchmark
* fix benchmark
* polish code
* add docstring and update readme
* refactor the code
* fix some logic bug of ppinfer
* polish readme
* fix typo
* skip infer test
* updated c++17 compiler flags (#4983)
* [Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)
* finish batch manager
* 1
* first
* fix
* fix dynamic batching
* llama infer
* finish test
* support different lengths generating
* del prints
* del prints
* fix
* fix bug
---------
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [inference] Async dynamic batching (#4894)
* finish input and output logic
* add generate
* test forward
* 1
* [inference]Re push async dynamic batching (#4901)
* adapt to ray server
* finish async
* finish test
* del test
---------
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)
This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a.
* Revert "[inference] Async dynamic batching (#4894)"
This reverts commit fced14025043e29ce816b315f440601188f7f79f.
* Revert "[inference] Async dynamic batching (#4894)" (#4909)
This reverts commit fced14025043e29ce816b315f440601188f7f79f.
* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* [infer]Add Ray Distributed Environment Init Scripts (#4911)
* Revert "[inference] Async dynamic batching (#4894)"
This reverts commit fced14025043e29ce816b315f440601188f7f79f.
* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* support dynamic batch for bloom model and is_running function
* [Inference]Test for new Async engine (#4935)
* infer engine
* infer engine
* test engine
* test engine
* new manager
* change step
* add
* test
* fix
* fix
* finish test
* finish test
* finish test
* finish test
* add license
---------
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* add assertion for config (#4947)
* [Inference] Finish dynamic batching offline test (#4948)
* test
* fix test
* fix quant
* add default
* fix
* fix some bugs
* fix some bugs
* fix
* fix bug
* fix bugs
* reset param
---------
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [Kernels]Updated Triton kernels to 2.1.0 and added flash-decoding for llama token attention (#4965)
* adding flash-decoding
* clean
* adding kernel
* adding flash-decoding
* add integration
* add
* adding kernel
* adding kernel
* adding triton 2.1.0 features for inference
* update bloom triton kernel
* remove useless vllm kernels
* clean codes
* fix
* adding files
* fix readme
* update llama flash-decoding
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
* fix ColossalEval (#4992)
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
* [doc]Update doc for colossal-inference (#4989)
* update doc
* Update README.md
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
* [hotfix] Fix the bug where process groups were not being properly released. (#4940)
* Fix the bug where process groups were not being properly released.
* test
* Revert "test"
This reverts commit 479900c1398637310abf92eefa3cd168038ea02f.
* [hotfix] fix the bug of repeatedly storing param group (#4951)
* [doc] add supported feature diagram for hybrid parallel plugin (#4996)
* [Pipeline Inference] Merge pp with tp (#4993)
* refactor pipeline into new CaiInferEngine
* update llama modeling forward
* merge tp with pp
* update docstring
* optimize test workflow and example
* fix typo
* add assert and todo
* [release] update version (#4995)
* [release] update version
* [hotfix] fix ci
* [moe] merge moe into main (#4978)
* update moe module
* support openmoe
* [hotfix] fix grad accumulation plus clipping for gemini (#5002)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
* Add layer norm gradients all-reduce for sequence parallel.
* skip pipeline inference test
* [hotfix] fixing policies of sequence parallel (#4922)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
---------
Co-authored-by: littsk <1214689160@qq.com>
* Hotfix/add grad all reduce for sequence parallel (#4927)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
* fix bug using wrong variables
---------
Co-authored-by: littsk <1214689160@qq.com>
* fix policy initialization
* fix bloom and chatglm policies
* polish code of handling layernorm
* fix moe module
* polish code of class initializing
---------
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
* [format] applied code formatting on changed files in pull request 4926 (#5007)
Co-authored-by: github-actions <github-actions@github.com>
* [Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)
* fix bug
* fix
* fix multiquery
* fix multiquery
---------
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [misc] add code owners (#5024)
* [moe] support optimizer checkpoint (#5015)
* Refactor MoE Manager setup method
* unshard optim ckpt
* optim io
* update transformer version
* update requirements
* update ckpt
* update ckpt
* update ckpt
* fix engine
* fix engine
* Support mtbench (#5025)
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
* [moe]: fix ep/tp tests, add hierarchical all2all (#4982)
* fix: add warning for EP different behavior
* fix: use shard_data in ep & tp model
* to: add used_capacity
* fix: fix router test
* feat: add create_ep_node_group
* feat: add create_ep_hierarchical_group fn
* feat: add HierarchicalAllToAll
* test: add hierarchical all2all test
* fix: fix test errors
* fix: simplify create_ep_hierarchical_group
* fix: add hierarchical_alltoall arg
* fix: fix environ typo
* revert: revert process mesh order
* to: add todo mark
* fix: skip hierarchical_comm if torch < 1.13.1
* [shardformer] Fix serialization error with Tensor Parallel state saving (#5018)
* Fix serialization error with Tensor Parallel state saving
* Refactor state_dict CPU transfer using tree_map
* [gemini] gemini support tensor parallelism. (#4942)
* [colossalai]fix typo
* [inference] Add smmoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843)
* [inference] add smoothquant llama attention (#4850)
* add smoothquant llama attention
* remove useless code
* remove useless code
* fix import error
* rename file name
* [inference] add silu linear fusion for smoothquant llama mlp (#4853)
* add silu linear
* update skip condition
* catch smoothquant cuda lib exception
* process exception for tests
* [inference] add llama mlp for smoothquant (#4854)
* add llama mlp for smoothquant
* fix down out scale
* remove duplicate lines
* add llama mlp check
* delete useless code
* [inference] add smoothquant llama (#4861)
* add smoothquant llama
* fix attention accuracy
* fix accuracy
* add kv cache and save pretrained
* refactor example
* delete smooth
* refactor code
* [inference] add smooth function and delete useless code for smoothquant (#4895)
* add smooth function and delete useless code
* update datasets
* remove duplicate import
* delete useless file
* refactor codes (#4902)
* refactor code
* add license
* add torch-int and smoothquant license
* Update flash_attention_patch.py
To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to the forward function of the attention layer.
https://github.com/huggingface/transformers/pull/25598
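A minimal sketch of the compatibility idea behind this patch, assuming a monkey-patched attention forward; the function below is illustrative, not the actual patch: accepting the newly added `padding_mask` keyword (and absorbing any future extras) keeps an attention patch working across the Transformers upgrade.

```python
# Illustrative only: tolerate the `padding_mask` argument newly passed by
# transformers (see the PR linked above) so an older attention patch does not break.
from typing import Optional

import torch


def patched_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,  # new kwarg from transformers
    **kwargs,  # absorb any further arguments added in future releases
) -> torch.Tensor:
    # ... the flash-attention computation itself is omitted in this sketch ...
    return hidden_states
```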
* [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)
* [kernel] support pure fp16 for cpu adam (#4896)
* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)
* [kernel] fix cpu adam
* [test] update gemini optim test
* [format] applied code formatting on changed files in pull request 4908 (#4918)
Co-authored-by: github-actions <github-actions@github.com>
* [gemini] support gradient accumulation (#4869)
* add test
* fix no_sync bug in low level zero plugin
* fix test
* add argument for grad accum
* add grad accum in backward hook for gemini
* finish implementation, rewrite tests
* fix test
* skip stuck model in low level zero test
* update doc
* optimize communication & fix gradient checkpoint
* modify doc
* cleaning codes
* update cpu adam fp16 case
* [hotfix] fix torch 2.0 compatibility (#4936)
* [hotfix] fix launch
* [test] fix test gemini optim
* [shardformer] fix vit
* [test] add no master test for low level zero plugin (#4934)
* [format] applied code formatting on changed files in pull request 4820 (#4886)
Co-authored-by: github-actions <github-actions@github.com>
* [nfc] fix some typo with colossalai/ docs/ etc. (#4920)
* [Refactor] Integrated some lightllm kernels into token-attention (#4946)
* add some req for inference
* clean codes
* add codes
* add some lightllm deps
* clean codes
* hello
* delete rms files
* add some comments
* add comments
* add doc
* add lightllm deps
* add lightllm chatglm2 kernels
* add lightllm chatglm2 kernels
* replace rotary embedding with lightllm kernel
* add some comments
* add some comments
* add some comments
* add
* replace fwd kernel att1
* fix an arg
* add
* add
* fix token attention
* add some comments
* clean codes
* modify comments
* fix readme
* fix bug
* fix bug
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo
* [test] update legacy test
* [test] update model zoo
* [test] update gemini test
* [test] remove components to test
* [inference] add reference and fix some bugs (#4937)
* add reference and fix some bugs
* update gptq init
---------
Co-authored-by: Xu Kai <xukai16@foxmail.com>
* [Inference]ADD Bench Chatglm2 script (#4963)
* add bench chatglm
* fix bug and make utils
---------
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [Pipeline inference] Combine kvcache with pipeline inference (#4938)
* merge kvcache with pipeline inference and refactor the code structure
* support ppsize > 2
* refactor pipeline code
* do pre-commit
* modify benchmark
* fix benchmark
* polish code
* add docstring and update readme
* refactor the code
* fix some logic bug of ppinfer
* polish readme
* fix typo
* skip infer test
* updated c++17 compiler flags (#4983)
* [Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)
* finish batch manager
* 1
* first
* fix
* fix dynamic batching
* llama infer
* finish test
* support different lengths generating
* del prints
* del prints
* fix
* fix bug
---------
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [inference] Async dynamic batching (#4894)
* finish input and output logic
* add generate
* test forward
* 1
* [inference]Re push async dynamic batching (#4901)
* adapt to ray server
* finish async
* finish test
* del test
---------
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)
This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a.
* Revert "[inference] Async dynamic batching (#4894)"
This reverts commit fced14025043e29ce816b315f440601188f7f79f.
* Revert "[inference] Async dynamic batching (#4894)" (#4909)
This reverts commit fced14025043e29ce816b315f440601188f7f79f.
* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* [infer]Add Ray Distributed Environment Init Scripts (#4911)
* Revert "[inference] Async dynamic batching (#4894)"
This reverts commit fced14025043e29ce816b315f440601188f7f79f.
* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* support dynamic batch for bloom model and is_running function
* [Inference]Test for new Async engine (#4935)
* infer engine
* infer engine
* test engine
* test engine
* new manager
* change step
* add
* test
* fix
* fix
* finish test
* finish test
* finish test
* finish test
* add license
---------
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* add assertion for config (#4947)
* [Inference] Finish dynamic batching offline test (#4948)
* test
* fix test
* fix quant
* add default
* fix
* fix some bugs
* fix some bugs
* fix
* fix bug
* fix bugs
* reset param
---------
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [Kernels]Updated Triton kernels to 2.1.0 and added flash-decoding for llama token attention (#4965)
* adding flash-decoding
* clean
* adding kernel
* adding flash-decoding
* add integration
* add
* adding kernel
* adding kernel
* adding triton 2.1.0 features for inference
* update bloom triton kernel
* remove useless vllm kernels
* clean codes
* fix
* adding files
* fix readme
* update llama flash-decoding
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
* fix ColossalEval (#4992)
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
* [doc]Update doc for colossal-inference (#4989)
* update doc
* Update README.md
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
* [hotfix] Fix the bug where process groups were not being properly released. (#4940)
* Fix the bug where process groups were not being properly released.
* test
* Revert "test"
This reverts commit 479900c1398637310abf92eefa3cd168038ea02f.
* [hotfix] fix the bug of repeatedly storing param group (#4951)
* [doc] add supported feature diagram for hybrid parallel plugin (#4996)
* [Pipeline Inference] Merge pp with tp (#4993)
* refactor pipeline into new CaiInferEngine
* update llama modeling forward
* merge tp with pp
* update docstring
* optimize test workflow and example
* fix typo
* add assert and todo
* [release] update version (#4995)
* [release] update version
* [hotfix] fix ci
* [gemini] gemini support tp
* fix
* update checkpointIO
* support fused layernorm
* update fusedlayernorm
* add sequence parallel to gemini
* fix
* fix comments
* fix
* fix t5
* clear cache
* fix
* activate ci
* activate ci
* fix
* fix
* fix
* fix
* revert
* modify tp gather method
* fix test
---------
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
* [hotfix] Support extra_kwargs in ShardConfig (#5031)
* [refactor]: replace inference args with extra_kwargs in ShardConfig
* modify shardconfig
* polish code
* fix policy bug in llama
* fix bug in auto policy
* remove setattr in ShardConfig
* fix wrong EOS token in ColossalChat
* [Kernels]Update triton kernels to 2.1.0 (#5046)
* update flash-context-attention
* adding kernels
* fix
* reset
* add build script
* add building process
* add llama2 example
* add colossal-llama2 test
* clean
* fall back test setting
* fix test file
* clean
* clean
* clean
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
* [pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when `strict=False`, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)
* Use p2p
* Cannot send p2p bidirectionally
* Refactor tensor creation and serialization in P2P communication
* Fix llama forward args in flash attention
* Add flop estimate from megatron
* Support loading weight not in weight_map when strict=False in hybrid_parallel
* Use send_forward_recv_backward, etc in 1f1b
* Use dataclass for metadata
Remove torch.cuda.synchronize() as suggested
* Add comment about the torch.cuda.synchronize for potential error
* Typo
* Update hybrid_parallel_checkpoint_io.py
* Update p2p.py
* Update one_f_one_b.py
* Update p2p.py
---------
Co-authored-by: flybird11111 <1829166702@qq.com>
* [gemini] gemini support extra-dp (#5043)
* support ddp
* fix
* fix
* fix
fix
* support ddp
* fix
* fix
* fix
fix
* simplify tests
* fix
* fix
* fix
fix
fix
* fix
* [shardformer] fix llama error when transformers upgraded. (#5055)
* fix-llama
* Update llama.py
* [hotfix]: modify create_ep_hierarchical_group and add test (#5032)
* feat: modify create_ep_hierarchical_group args
* test: add ep tests
* fix: remove get_process_group_ranks
* fix: fix src_rank
* [example] fix llama example's loss error when using gemini plugin (#5060)
fix llama example
* [inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)
* support only tp
* enable tp
* add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)
merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
* remove useless code
* fix quant model
* fix test import bug
* mv original inference legacy
* fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
* [Refactor] refactor policy search and quant type controlling in inference
* [inference] update readme (#5051)
* update readme
* update readme
* fix architecture
* fix table
* fix table
* [inference] update example (#5053)
* update example
* fix run.sh
* fix rebase bug
* fix some errors
* update readme
* add some features
* update interface
* update readme
* update benchmark
* add requirements-infer
---------
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
* [Kernels]added flash-decoding of triton (#5063)
* added flash-decoding of triton based on lightllm kernel
* add req
* clean
* clean
* delete build.sh
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
* [misc] remove outdated submodule (#5070)
* [npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
* [hotfix/hybridengine] fix bug when tp*pp size = 1 (#5069)
* [inference] update examples and engine (#5073)
* update examples and engine
* fix choices
* update example
* [format] applied code formatting on changed files in pull request 5067 (#5072)
Co-authored-by: github-actions <github-actions@github.com>
* [hotfix/hybridengine] Fix init model with random parameters in benchmark (#5074)
* fix init model with random parameters
* fix example
* [inference] refactor examples and fix schedule (#5077)
* [setup] refactor infer setup
* [hotfix] fix inference behavior on 1 1 gpu
* [example] refactor inference examples
* fix thrust-transform-reduce error (#5078)
* [nfc] fix typo in docs/ (#4972)
* [nfc] fix typo and author name (#5089)
* [gemini]fix gemini optimizer: saving Shardformer in Gemini got 'list assignment index out of range' (#5085)
* [Hotfix] Fix model policy matching strategy in ShardFormer (#5064)
* hotfix/Fix get model policy strategy in ShardFormer
* fix bug in auto policy
* [shardformer]fix flash attention: when the mask is causal, just don't unpad it (#5084)
* fix flash attn
* fix
fix
* [npu] add npu support for hybrid plugin and llama (#5090)
* llama 3d
* update
* fix autocast
* [Feature] Add document retrieval QA (#5020)
* add langchain
* add langchain
* Add files via upload
* add langchain
* fix style
* fix style: remove extra space
* add pytest; modified retriever
* add pytest; modified retriever
* add tests to build_on_pr.yml
* fix build_on_pr.yml
* fix build on pr; fix environ vars
* separate unit tests for colossalqa from build_on_pr
* fix container setting; fix environ vars
* commented dev code
* add incremental update
* remove stale code
* fix style
* change to sha3 224
* fix retriever; fix style; add unit test for document loader
* fix ci workflow config
* fix ci workflow config
* add set cuda visible device script in ci
* fix doc string
* fix style; update readme; refactored
* add force log info
* change build on pr, ignore colossalqa
* fix docstring, capitalize all initial letters
* fix indexing; fix text-splitter
* remove debug code, update reference
* reset previous commit
* update LICENSE, update README, add key-value mode, fix bugs
* add files back
* revert force push
* remove junk file
* add test files
* fix retriever bug, add intent classification
* change conversation chain design
* rewrite prompt and conversation chain
* add ui v1
* ui v1
* fix avatar
* add header
* Refactor the RAG Code and support Pangu
* Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo.
* resolved conversation. tested scripts under examples. web demo still buggy
* fix ci tests
* Some modifications to add ChatGPT api
* modify llm.py and remove unnecessary files
* Delete applications/ColossalQA/examples/ui/test_frontend_input.json
* Remove OpenAI api key
* add colossalqa
* move files
* move files
* move files
* move files
* fix style
* Add Readme and fix some bugs.
* Add something to readme and modify some code
* modify a directory name for clarity
* remove redundant directory
* Correct a type in llm.py
* fix AI prefix
* fix test_memory.py
* fix conversation
* fix some errors and typos
* Fix a missing import in RAG_ChatBot.py
* add colossalcloud LLM wrapper, correct issues in code review
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
* remove duplicate import (#5100)
* fix typo change lazy_iniy to lazy_init (#5099)
* [nfc] fix typo change directoty to directory (#5111)
* [FEATURE] Add Safety Eval Datasets to ColossalEval (#5095)
* add safetybench and cvalues(responsibility) eval dataset
* Modify code according to review suggestions
---------
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
* [hotfix] fixed memory usage of shardformer module replacement (#5122)
* [shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)
* [shardformer] implement policy for all GPT-J models and test
* [shardformer] support interleaved pipeline parallel for bert finetune
* [shardformer] shardformer support falcon (#4883)
* [shardformer]: fix interleaved pipeline for bert model (#5048)
* [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093)
* Add Mistral support for Shardformer (#5103)
* [shardformer] add tests to mistral (#5105)
---------
Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
* [doc] add moe news (#5128)
* [doc] add moe news
* [doc] add moe news
* [doc] add moe news
* [doc] updated paper citation (#5131)
* fix typo change JOSNL TO JSONL etc. (#5116)
* [format] applied code formatting on changed files in pull request 5088 (#5127)
Co-authored-by: github-actions <github-actions@github.com>
* [format] applied code formatting on changed files in pull request 5124 (#5125)
Co-authored-by: github-actions <github-actions@github.com>
* [format] applied code formatting on changed files in pull request 5115 (#5118)
Co-authored-by: github-actions <github-actions@github.com>
* [accelerator] init the accelerator module (#5129)
* [accelerator] init the accelerator module
* polish code
* polish code
* polish code
* polish code
* [npu] support triangle attention for llama (#5130)
* update fused attn
* update sdpa
* tri attn
* update triangle
* import
* fix
* fix
* [plugin]fix 3d checkpoint load when booster boost without optimizer. (#5135)
* fix 3d checkpoint load when booster boost without optimizer
fix 3d checkpoint load when booster boost without optimizer
* test ci
* revert ci
* fix
fix
* [ColossalQA] refactor server and webui & add new feature (#5138)
* refactor server and webui & add new feature
* add requirements
* modify readme and ui
* [doc] fix colossalqa document (#5146)
* fix doc
* modify doc
* fix (#5158)
fix
* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)
* Add finetuning Colossal-Llama-2 example
* Add finetuning Colossal-Llama-2 example 2
* Add finetuning Colossal-Llama-2 example and support NEFTuning
* Add inference example and refine neftune
* Modify readme file
* update the imports
---------
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
* [gemini] hotfix NaN loss while using Gemini + tensor_parallel (#5150)
* fix
aaa
fix
fix
fix
* fix
* fix
* test ci
* fix ci
fix
* [colossalqa] fix pangu api (#5170)
* fix pangu api
* add comment
* [ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)
* Support GSM, Data Leakage Evaluation and Tensor Parallel
* remove redundant code and update inference.py in examples/gpt_evaluation
---------
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
* [shardformer] llama support DistCrossEntropy (#5176)
* fix
aaa
fix
fix
fix
* fix
* fix
* test ci
* fix ci
fix
* llama support dist-cross
fix
fix
fix
fix
fix
fix
fix
fix
* fix
* fix
* fix
fix
* test ci
* test ci
* fix
* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)
* Add finetuning Colossal-Llama-2 example
* Add finetuning Colossal-Llama-2 example 2
* Add finetuning Colossal-Llama-2 example and support NEFTuning
* Add inference example and refine neftune
* Modify readme file
* update the imports
---------
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
* llama support dist-cross
fix
fix
fix
fix
fix
fix
fix
fix
* fix
* fix
* fix
fix
* test ci
* test ci
* fix
* fix ci
* fix ci
---------
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
* Fix ColossalEval (#5186)
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
* [doc] update pytorch version in documents. (#5177)
* fix
aaa
fix
fix
fix
* fix
* fix
* test ci
* fix ci
fix
* update pytorch version in documents
* polish readme in application/chat (#5194)
* [pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward as p2p op list need to use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
* Improve logic for selecting metrics (#5196)
Co-authored-by: Xu <yuanchen.xu00@gmail.com>
* [doc] Update required third-party library list for testing and torch compatibility checking (#5207)
* doc/update requirements-test.txt
* update torch-cuda compatibility check
* support linear accumulation fusion (#5199)
support linear accumulation fusion
support linear accumulation fusion
fix
* [pipeline]: support arbitrary batch size in forward_only mode (#5201)
* fix: remove drop last in val & test dataloader
* feat: add run_forward_only, support arbitrary bs
* chore: modify ci script
* [pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)
* fix: add fallback order option and update 1f1b
* fix: fix deadlock comm in interleaved pp
* test: modify p2p test
* [devops] update torch version in ci (#5217)
* fix-test (#5210)
fix-test
fix-test
* fix flash attn (#5209)
* [nfc] fix typo colossalai/shardformer/ (#5133)
* [Colossal-LLaMA-2] Release Colossal-LLaMA-2-13b-base model (#5224)
* update readme
* update readme
* update link
* update
* update readme
* update
* update
* update
* update title
* update example
* update example
* fix content
* add conclusion
* add license
* update
* update
* update version
* fix minor
* [doc] Update README.md of Colossal-LLAMA2 (#5233)
* Update README.md
* Update README.md
* [doc] Make leaderboard format more uniform and good-looking (#5231)
* Make leaderboard format more unified and good-looking
* Update README.md
* Update README.md
* [doc] add Colossal-LLaMA-2-13B (#5234)
* [doc] add Colossal-LLaMA-2-13B
* [doc] add Colossal-LLaMA-2-13B
* [doc] add Colossal-LLaMA-2-13B
* [format] applied code formatting on changed files in pull request 5234 (#5235)
Co-authored-by: github-actions <github-actions@github.com>
* [doc] SwiftInfer release (#5236)
* [doc] SwiftInfer release
* [doc] SwiftInfer release
* [doc] SwiftInfer release
* [doc] SwiftInfer release
* [doc] SwiftInfer release
* [npu] use extension for op builder (#5172)
* update extension
* update cpu adam
* update is
* add doc for cpu adam
* update kernel
* update commit
* update flash
* update memory efficient
* update flash attn
* update flash attention loader
* update api
* fix
* update doc
* update example time limit
* reverse change
* fix doc
* remove useless kernel
* fix
* not use warning
* update
* update
* [pipeline] A more general _communicate in p2p (#5062)
* A more general _communicate
* feat: finish tree_flatten version p2p
* fix: update p2p api calls
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code
---------
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
* [hotfix] removed unused flag (#5242)
* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)
* [workflow] fixed build CI (#5240)
* [workflow] fixed build CI
* polish
* polish
* polish
* polish
* polish
* [ci] fixed booster test (#5251)
* [ci] fixed booster test
* [ci] fixed booster test
* [ci] fixed booster test
* [ci] fixed ddp test (#5254)
* [ci] fixed ddp test
* polish
* fix typo in applications/ColossalEval/README.md (#5250)
* [ci] fix shardformer tests. (#5255)
* fix ci
fix
* revert: revert p2p
* feat: add enable_metadata_cache option
* revert: enable t5 tests
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [doc] fix doc typo (#5256)
* [doc] fix annotation display
* [doc] fix llama2 doc
* [hotfix]: add pp sanity check and fix mbs arg (#5268)
* fix: fix misleading mbs arg
* feat: add pp sanity check
* fix: fix 1f1b sanity check
* [workflow] fixed incomplete bash command (#5272)
* [workflow] fixed oom tests (#5275)
* [workflow] fixed oom tests
* polish
* polish
* polish
* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)
* fix ci
fix
* fix test
* revert: revert p2p
* feat: add enable_metadata_cache option
* revert: enable t5 tests
* fix
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [shardformer] hybridparallelplugin support gradients accumulation. (#5246)
* support gradients acc
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
* fix
fix
* fix
fix
fix
* [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)
* fix auto loading gpt2 tokenizer (#5279)
* [doc] add llama2-13B display (#5285)
* Update README.md
* fix 13b typo
---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* fix llama pretrain (#5287)
* [hotfix] fix 3d plugin test (#5292)
* fix bug for mefture (#5299)
* [NFC] polish applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py code style (#5228)
* fix some typo (#5307)
* [feat] refactored extension module (#5298)
* [feat] refactored extension module
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* [workflow] updated CI image (#5318)
* [accelerator] fixed npu api
* [tests] fix t5 test. (#5322)
* [ci] fix shardformer tests. (#5255)
* fix ci
fix
* revert: revert p2p
* feat: add enable_metadata_cache option
* revert: enable t5 tests
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* fix t5 test
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [doc] added docs for extensions (#5324)
* [doc] added docs for extensions
* polish
* polish
* fix typo under extensions/ (#5330)
* fix typo change dosen't to doesn't (#5308)
* [extension] fixed exception catch (#5342)
* [Chat] fix sft loss nan (#5345)
* fix script
* fix script
* fix chat nan
* fix chat nan
* [checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)
* [checkpointio] fix hybrid parallel optim checkpoint
* [extension] fix cuda extension
* [checkpointio] fix gemini optimizer checkpoint
* polish code
* [fix] remove unnecessary dp_size assert (#5351)
* fix: remove unnecessary assert
* test: add more 3d plugin tests
* fix: add warning
* [gemini] fix param op hook when output is tuple (#5355)
* [gemini] fix param op hook when output is tuple
* [gemini] fix param op hook
* [llama] fix dataloader for hybrid parallel (#5358)
* [plugin] refactor prepare dataloader
* [plugin] update train script
* [llama] update training script (#5360)
* [llama] update training script
* [doc] polish docstr
* [llama] add flash attn patch for npu (#5362)
* [llama] fix neftune & pbar with start_step (#5364)
* [eval] update llama npu eval (#5366)
* [llama] polish training script and fix optim ckpt (#5368)
* [lr-scheduler] fix load state dict and add test (#5369)
* [llama] fix memory issue (#5371)
* [llama] fix memory issue
* [llama] add comment
* [moe] init mixtral impl
* [moe] update capacity computing (#5253)
* [moe] top2 allow uneven input
* [moe] update capacity computing
* [moe] remove debug info
* [moe] update capacity computing
* [moe] update capacity computing
* [moe] support mixtral (#5309)
* [moe] add mixtral block for single expert
* [moe] mixtral block fwd support uneven ep
* [moe] mixtral block bwd support uneven ep
* [moe] add mixtral moe layer
* [moe] simplify replace
* [moe] support save sharded mixtral
* [moe] support load sharded mixtral
* [moe] support save sharded optim
* [moe] integrate moe manager into plugin
* [moe] fix optimizer load
* [moe] fix mixtral layer
* [moe] fix mixtral checkpoint io (#5314)
* [moe] fix mixtral forward default value (#5329)
* [moe] fix mixtral optim checkpoint (#5344)
* [moe] fix tests
* [release] update version (#5380)
* [llama] fix training and inference scripts (#5384)
* [llama] refactor inference example to fit sft
* [llama] fix training script to fit gemini
* [llama] fix inference script
* [doc] Fix typo (#5361)
* [doc] updated installation command (#5389)
* [hotfix] fix variable type for top_p (#5313)
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [hotfix] Fix wrong import in meta_registry (#5392)
* [extension] hotfix jit extension setup (#5402)
* [example] reuse flash attn patch (#5400)
* [fsdp] impl save/load shard model/optimizer (#5357)
* [setup] fixed nightly release (#5388)
* [shardformer]gather llama logits (#5398)
* gather llama logits
* fix
* update requirements (#5407)
* [workflow] added pypi channel (#5412)
* [doc] fix blog link
* [doc] fix blog link
* fix sft single turn inference example (#5416)
* [example]add gpt2 benchmark example script. (#5295)
* benchmark gpt2
* fix
fix
fix
fix
* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)
* [workflow] fixed build CI (#5240)
* [workflow] fixed build CI
* polish
* polish
* polish
* polish
* polish
* [ci] fixed booster test (#5251)
* [ci] fixed booster test
* [ci] fixed booster test
* [ci] fixed booster test
* [ci] fixed ddp test (#5254)
* [ci] fixed ddp test
* polish
* fix typo in applications/ColossalEval/README.md (#5250)
* [ci] fix shardformer tests. (#5255)
* fix ci
fix
* revert: revert p2p
* feat: add enable_metadata_cache option
* revert: enable t5 tests
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [doc] fix doc typo (#5256)
* [doc] fix annotation display
* [doc] fix llama2 doc
* [hotfix]: add pp sanity check and fix mbs arg (#5268)
* fix: fix misleading mbs arg
* feat: add pp sanity check
* fix: fix 1f1b sanity check
* [workflow] fixed incomplete bash command (#5272)
* [workflow] fixed oom tests (#5275)
* [workflow] fixed oom tests
* polish
* polish
* polish
* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)
* fix ci
fix
* fix test
* revert: revert p2p
* feat: add enable_metadata_cache option
* revert: enable t5 tests
* fix
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [shardformer] hybridparallelplugin support gradients accumulation. (#5246)
* support gradients acc
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
fix
* fix
fix
* fix
fix
fix
* [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)
* fix auto loading gpt2 tokenizer (#5279)
* [doc] add llama2-13B display (#5285)
* Update README.md
* fix 13b typo
---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* fix llama pretrain (#5287)
* fix
* fix
* fix
fix
* fix
fix
fix
* fix
fix
* benchmark gpt2
* fix
fix
fix
fix
* [workflow] fixed build CI (#5240)
* [workflow] fixed build CI
* polish
* polish
* polish
* polish
* polish
* [ci] fixed booster test (#5251)
* [ci] fixed booster test
* [ci] fixed booster test
* [ci] fixed booster test
* fix
fix
* fix
fix
fix
* fix
* fix
fix
fix
fix
fix
* fix
* Update shardformer.py
---------
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
* [doc] sora release (#5425)
* [doc] sora release
* [doc] sora release
* [doc] sora release
* [doc] sora release
* [devops] fix extension building (#5427)
* [hotfix] fix sd vit import error (#5420)
* fix import error
* Update dpt_depth.py
---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [hotfix] fix typo of openmoe model source (#5403)
* [doc] update some translations with README-zh-Hans.md (#5382)
* [hotfix] fix typo change _descrption to _description (#5331)
* [hotfix] fix typo change enabel to enable under colossalai/shardformer/ (#5317)
* [eval-hotfix] set few_shot_data to None when few shot is disabled (#5422)
* [hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335)
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [doc] Fix typo s/infered/inferred/ (#5288)
Signed-off-by: hugo-syn <hugo.vincent@synacktiv.com>
* [hotfix] fix stable diffusion inference bug. (#5289)
* Update train_ddp.yaml
delete "strategy" to fix DDP config loading bug in "main.py"
* Update train_ddp.yaml
fix the config file loading bug when running inference with scripts/txt2img.py.
* Update README.md
add pretrained model test code.
* [colossal-llama2] add stream chat example for chat version model (#5428)
* add stream chat for chat version
* remove os.system clear
* modify function name
* [release] update version (#5411)
* fix tensor data update for gemini loss calculation (#5442)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallelism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
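A hedged sketch of the vocab-padding idea referenced in the commits above: round the vocabulary size up so the embedding / lm-head rows divide evenly across parallel ranks before the output is gathered. The helper name and arguments are illustrative, not ColossalAI's API.

```python
# Illustrative helper: pad vocab_size so it splits evenly across tp ranks.
def padded_vocab_size(vocab_size: int, tp_size: int, multiple_of: int = 1) -> int:
    step = tp_size * multiple_of
    return ((vocab_size + step - 1) // step) * step


# e.g. model.resize_token_embeddings(padded_vocab_size(len(tokenizer), tp_size=4))
```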
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [example] add test ci
* [example] update readme
* [release] grok-1 314b inference (#5490)
* [release] grok-1 inference
* [release] grok-1 inference
* [release] grok-1 inference
* [example] update Grok-1 inference (#5495)
* revise grok-1 example
* remove unused arg in scripts
* prevent re-installing torch
* update readme
* revert modifying colossalai requirements
* add perf
* trivial
* add tokenizer url
* [hotfix] set return_outputs=False in examples and polish code (#5404)
* fix: simplify merge_batch
* fix: use return_outputs=False to eliminate extra memory consumption
* feat: add return_outputs warning
* style: remove `return_outputs=False` as it is the default value
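A hedged sketch of the pattern these commits describe, assuming a `Booster.execute_pipeline` call that returns a dict with a "loss" entry (names are taken from the commit subjects and may differ from the released API): leaving `return_outputs` at its default (False) avoids keeping every micro-batch output alive.

```python
# Sketch only; argument and return names are assumptions, not a verified signature.
def pipeline_train_step(booster, model, criterion, optimizer, data_iter):
    out = booster.execute_pipeline(
        data_iter, model, criterion, optimizer, return_loss=True
        # return_outputs defaults to False; set it to True only when the
        # concatenated model outputs are actually needed.
    )
    return out["loss"]
```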
* [release] grok-1 inference benchmark (#5500)
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [shardformer]Fix lm parallel. (#5480)
* fix
* padding vocab_size when using pipeline parallelism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* fix lm forward distribution
* fix
* test ci
* fix
* [fix] fix grok-1 example typo (#5506)
* [devops] fix example test ci (#5504)
* Fix ColoTensorSpec for py11 (#5440)
* fixed layout converter caching and updated tester
* Empty-Commit
* [shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
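A generic, hedged illustration of the kind of custom-mask preparation mentioned above: combining a padding mask with a causal mask into one additive attention bias. This shows the general technique only, not ColossalAI's ColoAttention code.

```python
import torch


def build_attention_bias(padding_mask: torch.Tensor, causal: bool = True) -> torch.Tensor:
    """padding_mask: [batch, seq_len], 1 for real tokens, 0 for padding."""
    b, s = padding_mask.shape
    bias = torch.zeros(b, 1, s, s)
    # block attention to padded key positions
    bias.masked_fill_(padding_mask[:, None, None, :] == 0, float("-inf"))
    if causal:
        # block attention to future positions
        future = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
        bias.masked_fill_(future, float("-inf"))
    return bias  # added to attention scores before softmax
```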
* [format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>
* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
* Change static methods for t5 layer distribution to member functions
* Change static methods for whisper layer distribution to member functions
* Replace whisper policy usage with self one
* Fix test case to use non-static layer distribution methods
* fix: fix typo
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)
* [fix] use tokenizer from the same pretrained path
* trust remote code
* [ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporary
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit
---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
* fix incorrect sharding without zero (#5545)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* fix bugs when using sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megatron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several model
* remove useless code
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
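A hedged sketch of the Ulysses-style all-to-all that several commits above refer to ("support flash attention for ulysses-style sp", "using all_to_all_single when batch size is 1"): each rank starts with a slice of the sequence for all heads, and one all-to-all trades it for the full sequence of a subset of heads before attention. Shapes and names are illustrative, not ColossalAI's exact implementation.

```python
import torch
import torch.distributed as dist


def seq_to_head_all_to_all(x: torch.Tensor, sp_group) -> torch.Tensor:
    """[batch, seq/sp, heads, dim] -> [batch, seq, heads/sp, dim] across the sp group."""
    sp_size = dist.get_world_size(sp_group)
    b, s_local, h, d = x.shape
    # one head-chunk per destination rank along dim 0
    inp = x.reshape(b, s_local, sp_size, h // sp_size, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(inp)
    dist.all_to_all_single(out, inp, group=sp_group)
    # received chunks are indexed by source rank, i.e. by sequence-slice position
    return out.permute(1, 0, 2, 3, 4).reshape(b, s_local * sp_size, h // sp_size, d)
```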
* [hotfix] quick fixes to make legacy tutorials runnable (#5559)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [fix] fix typo s/muiti-node /multi-node etc. (#5448)
* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)
* [devops] remove post commit ci (#5566)
* [devops] remove post commit ci
* [misc] run pre-commit on all files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [doc] fix ColossalMoE readme (#5599)
* fix readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [zero] support multiple (partial) backward passes (#5596)
* [zero] support multiple (partial) backward passes
* [misc] update requirements
* [shardformer] refactor embedding resize (#5603)
* [branch rebase] rebase main to Feature/resize_embedding (#5554)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallelism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [example] add test ci
* [example] update readme
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [CI] run pre-commit (#5577)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallelism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [example] add test ci
* [example] update readme
* run pre-commit
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [rebase] rebase main to resize-embedding (#5581)
* [release] grok-1 314b inference (#5490)
* [release] grok-1 inference
* [release] grok-1 inference
* [release] grok-1 inference
* [example] update Grok-1 inference (#5495)
* revise grok-1 example
* remove unused arg in scripts
* prevent re-installing torch
* update readme
* revert modifying colossalai requirements
* add perf
* trivial
* add tokenizer url
* [hotfix] set return_outputs=False in examples and polish code (#5404)
* fix: simplify merge_batch
* fix: use return_outputs=False to eliminate extra memory consumption
* feat: add return_outputs warning
* style: remove `return_outputs=False` as it is the default value
* [release] grok-1 inference benchmark (#5500)
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [shardformer]Fix lm parallel. (#5480)
* fix
* padding vocab_size when using pipeline parallelism
padding vocab_size when using pipeline parallelism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* fix lm forward distribution
* fix
* test ci
* fix
* [fix] fix grok-1 example typo (#5506)
* [devops] fix example test ci (#5504)
* Fix ColoTensorSpec for py11 (#5440)
* fixed layout converter caching and updated tester
* Empty-Commit
* [shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
* [format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>
* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
* Change static methods for t5 layer distribution to member functions
* Change static methods for whisper layer distribution to member functions
* Replace whisper policy usage with self one
* Fix test case to use non-static layer distribution methods
* fix: fix typo
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)
* [fix] use tokenizer from the same pretrained path
* trust remote code
* [ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporary
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit
---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
* fix incorrect sharding without zero (#5545)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* fix bugs when using sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megatron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several models
* remove useless code
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* [hotfix] quick fixes to make legacy tutorials runnable (#5559)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [fix] fix typo s/muiti-node /multi-node etc. (#5448)
* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)
* [devops] remove post commit ci (#5566)
* [devops] remove post commit ci
* [misc] run pre-commit on all files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [shardformer]enable padding vocabulary size. (#5489)
* padding vocab_size when using pipeline parallelism
padding vocab_size when using pipeline parallelism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* padding vocab
* padding vocab
* fix
* fix
* fix
* test ci
* fix
fix
fix
fix
* fix
fix
* fix
* fix
* Update hybrid_parallel_plugin.py
fix
fix
fix
* fix
fix
* fix
fix
* fix
* resolve super init
resolve super init
resolve super init
resolve super init
* resolve comments
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* vocab checkpointio
* padding vocab_size when using pipeline parallelism
padding vocab_size when using pipeline parallelism
fix
fix
* fix
fix
fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* padding vocab
* fix
* fix
fix
* fix
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* cherry-pick
* revert moe modify
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
fix
fix
fix
fix
fix
fix
fix
* resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
* ptensor
ptensor
resolve comments
fix
fix
fix
fix
fix
resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix rebase
* fix rebase
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606)
* fix no pad token bug
* fixed some auto parallel codegen bug, but might not run on torch 2.1
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [shardformer] fix pipeline grad ckpt (#5620)
* [shardformer] fix pipeline grad ckpt
* [lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement
* add license and implement apis
* add checkpointio apis
* add torchddp fwd_bwd test
* add support_lora methods
* add checkpointio test and debug
* delete unneeded codes
* remove peft from LICENSE
* add concrete methods for enable_lora
* simplify enable_lora api
* fix requirements
* [LowLevelZero] low level zero support lora (#5153)
* low level zero support lora
low level zero support lora
* add checkpoint test
* add checkpoint test
* fix
* fix
* fix
* fix
fix
fix
fix
* fix
* fix
fix
fix
fix
fix
fix
fix
* fix
* fix
fix
fix
fix
fix
fix
fix
* fix
* test ci
* git # This is a combination of 3 commits.
Update low_level_zero_plugin.py
Update low_level_zero_plugin.py
fix
fix
fix
* fix naming
fix naming
fix naming
fix
* [feature] qlora support
* qlora follow commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* migrate quantization folder to colossalai/
* minor fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* gptj sp fix
* remove redundancies from pre-commit
* minor fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: hugo-syn <hugo.vincent@synacktiv.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Jun Gao <imgaojun@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxamil.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Elsa Granger <zeyugao@outlook.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: JIMMY ZHAO <knightyzhao@gmail.com>
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
Co-authored-by: 李文军 <40464906+liwenjuna@users.noreply.github.com>
Co-authored-by: yixiaoer <miyaku@yixiaoer.sg>
Co-authored-by: CZYCW <czyczf@163.com>
Co-authored-by: Stephan Kölker <stephankoe@users.noreply.github.com>
Co-authored-by: QinLuo <eric.x.sun@gmail.com>
Co-authored-by: MickeyCHAN <76671016+danyow-cheung@users.noreply.github.com>
Co-authored-by: Luo Yihang <luo_yihang@outlook.com>
Co-authored-by: Dongruixuan Li <dongruixuan@hotmail.com>
Co-authored-by: hugo-syn <61210734+hugo-syn@users.noreply.github.com>
Co-authored-by: Youngon <Youngon_wyl@163.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

import math
from abc import ABC
from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup

from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER


class MoeRouter(nn.Module, ABC):
    """Base class for all MoE routers.

    Args:
        k_value (int): The value of top_k.
        capacity_factor_train (float): Capacity factor in routing during training.
        capacity_factor_eval (float): Capacity factor in routing during evaluation.
        min_capacity (int): The minimum capacity of each expert.
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether to drop tokens during evaluation.
    """

    def __init__(
        self,
        k_value: int,
        capacity_factor_train: float,
        capacity_factor_eval: float,
        min_capacity: int,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
        use_kernel: bool = False,
    ):
        super().__init__()
        self.k_value = k_value
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks
        self._aux_loss = None
        self._z_loss = None
        self.use_kernel = use_kernel

    def get_capacity(self, num_tokens, num_experts, ep_group=None):
        if ep_group is not None:
            num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
            dist.all_reduce(num_tokens_tensor, group=ep_group)
            num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
        capacity += capacity % 2
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return int(capacity)

    def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
        """Computes the auxiliary load balancing loss as in Switch Transformer.

        See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
        implements the loss function presented in equations (4) - (6). It aims to
        penalize cases where the routing between experts is unbalanced.

        Args:
            router_probs: Probability assigned to each expert per token. Shape:
                <float32>[num_groups, tokens_per_group, num_experts].
            expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
                indices identifying the top num_selected_experts for a given token.
        """
        assert self._aux_loss is None
        if router_probs.dim() == expert_indices.dim() == 2:
            router_probs = router_probs.unsqueeze(0)
            expert_indices = expert_indices.unsqueeze(0)
        assert (
            router_probs.dim() == expert_indices.dim() == 3
        ), "router_probs and expert_indices must be 3D tensors"

        # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
        expert_mask = F.one_hot(expert_indices, num_experts)
        # For a given token, determine if it was routed to a given expert.
        # Shape: [num_groups, tokens_per_group, num_experts]
        expert_mask = expert_mask.max(dim=-2)[0]

        tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
        router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
        aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
        self._aux_loss = aux_loss

    def set_z_loss(self, router_logits: torch.Tensor):
        """Compute the router z-loss.

        The router z-loss was introduced in Designing Effective Sparse Expert Models
        (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
        small in an effort to improve stability.

        Args:
            router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
        """
        assert self._z_loss is None
        if router_logits.dim() == 2:
            router_logits = router_logits.unsqueeze(0)
        assert router_logits.dim() == 3, "router_logits must be a 3D tensor"
        num_groups, tokens_per_group, _ = router_logits.shape
        log_z = torch.logsumexp(router_logits, dim=-1)
        z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
        self._z_loss = z_loss

    def pop_router_loss(self) -> torch.Tensor:
        assert self._aux_loss is not None
        MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
        self._aux_loss = None
        self._z_loss = None
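
For reference, the two regularizers tracked by `set_aux_loss` and `set_z_loss` can be reproduced with plain PyTorch on dummy router logits. The sketch below is illustrative only: the shapes follow the docstrings above, and the tensors are random stand-ins rather than anything produced by a ColossalAI run.

import torch
import torch.nn.functional as F

num_groups, tokens_per_group, num_experts = 1, 8, 4
router_logits = torch.randn(num_groups, tokens_per_group, num_experts)
router_probs = F.softmax(router_logits, dim=-1)
# Top-2 expert choices per token: [num_groups, tokens_per_group, 2].
expert_indices = torch.topk(router_probs, k=2, dim=-1).indices

# Auxiliary load-balancing loss (Switch Transformer, eqs. 4-6), mirroring set_aux_loss.
expert_mask = F.one_hot(expert_indices, num_experts).max(dim=-2)[0]
aux_loss = num_experts**2 * torch.mean(expert_mask.float().mean(dim=-2) * router_probs.mean(dim=-2))

# Router z-loss (ST-MoE), mirroring set_z_loss: keeps logits small for numerical stability.
log_z = torch.logsumexp(router_logits, dim=-1)
z_loss = torch.sum(log_z**2) / (num_groups * tokens_per_group)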

class Top1Router(MoeRouter):
    """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
    and the combine weight (batch_size * seq_len, num_experts, capacity) for routing. More details
    can be found in Google's Switch Transformer paper.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing during training.
        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation.
        min_capacity (int, optional): The minimum capacity of each expert.
        select_policy (str, optional): The policy used for token selection.
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether to drop tokens during evaluation.
    """

    def __init__(
        self,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        select_policy: str = "first",
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            k_value=1,
            capacity_factor_train=capacity_factor_train,
            capacity_factor_eval=capacity_factor_eval,
            min_capacity=min_capacity,
            noisy_func=noisy_func,
            drop_tks=drop_tks,
        )
        self.select_policy = select_policy
        assert select_policy in {"first", "random"}
        if select_policy == "random":
            self.uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
                high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
            ).rsample

    def forward(
        self,
        inputs: torch.Tensor,
        use_kernel: bool = False,
        ep_group: Optional[ProcessGroup] = None,
        use_loss: bool = False,
        use_norm: bool = False,
    ) -> Tuple:
        """
        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

        Returns:
            1. use_kernel is False:
                The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
                The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
            2. use_kernel is True:
                ...
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        assert inputs.dtype == torch.float
        probs = F.softmax(inputs, dim=-1)
        num_experts = probs.size(-1)
        num_tokens = inputs.size(0)
        capacity = self.get_capacity(num_tokens, num_experts, ep_group)

        top1_idx = torch.argmax(inputs, dim=-1)
        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

        # calculate the router loss
        self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
        self.set_z_loss(inputs)
        self.pop_router_loss()

        if not self.training and not self.drop_tks and ep_group is not None:
            max_num = torch.max(torch.sum(mask, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = max_num.item()

        if self.select_policy == "random":
            rand_mask = mask * self.uniform(mask.shape)
            _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
            ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
        elif self.select_policy == "first":
            ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
            mask = mask * torch.lt(ranks, capacity)
        else:
            raise NotImplementedError("This select policy is not supported yet.")

        ranks = torch.sum(mask * ranks, dim=-1)
        used_capacity = mask.sum(dim=0)

        if use_kernel:
            mask = torch.sum(mask, dim=-1)
            mask = torch.stack([mask], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
            return used_capacity, probs, mask, dest_idx, num_experts * capacity
        else:
            ranks = F.one_hot(ranks, num_classes=capacity)
            weight = mask * probs.type_as(inputs)
            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
            sec_mask = combine_weights.bool()
            return used_capacity, combine_weights, sec_mask, probs
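
To make the `select_policy="first"` branch above concrete, here is a small standalone walk-through of the capacity truncation step. A plain `torch.cumsum(mask, dim=0) - 1` stands in for `moe_cumsum`; that substitution is an assumption made only for this illustration.

import torch
import torch.nn.functional as F

num_experts, capacity = 2, 2
top1_idx = torch.tensor([0, 0, 0, 1, 0, 1])          # the expert chosen by each of 6 tokens
mask = F.one_hot(top1_idx, num_classes=num_experts)  # [num_tokens, num_experts]

# 0-based position of each token in its expert's queue (stand-in for moe_cumsum).
ranks = torch.cumsum(mask, dim=0) - 1

# Tokens beyond an expert's capacity are dropped.
kept = mask * torch.lt(ranks, capacity)
print(kept.sum(dim=0))  # tensor([2, 2]): each expert keeps at most `capacity` tokens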

class Top2Router(MoeRouter):
    """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
    and the combine weight (batch_size * seq_len, num_experts, capacity) for routing. More details
    can be found in the ViT-MoE paper.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing during training.
        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation.
        min_capacity (int, optional): The minimum capacity of each expert.
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether to drop tokens during evaluation.
    """

    def __init__(
        self,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            k_value=2,
            capacity_factor_train=capacity_factor_train,
            capacity_factor_eval=capacity_factor_eval,
            min_capacity=min_capacity,
            noisy_func=noisy_func,
            drop_tks=drop_tks,
        )

    def forward(
        self,
        inputs: torch.Tensor,
        use_kernel: bool = False,
        ep_group: Optional[ProcessGroup] = None,
        use_norm: bool = False,
        use_loss: bool = True,
    ) -> Tuple:
        """
        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

        Returns:
            1. use_kernel is False:
                The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
                The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
            2. use_kernel is True:
                ...
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        assert inputs.dtype == torch.float
        probs = F.softmax(inputs, dim=-1)
        if use_norm:
            routing_weights, _ = torch.topk(probs, 2, dim=-1)
            probs = probs / routing_weights.sum(dim=-1, keepdim=True)

        num_experts = probs.size(-1)
        num_tokens = inputs.size(0)
        capacity = self.get_capacity(num_tokens, num_experts, ep_group)

        top1_idx = torch.argmax(probs, dim=-1)
        mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
        logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
        top2_idx = torch.argmax(logits_except1, dim=-1)
        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

        cmask = mask1 + mask2  # loss: [s, e]
        cmask = cmask.float() / 2.0  # div 2 to normalize it to 1

        # calculate the router loss
        if use_loss:
            expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
            self.set_aux_loss(probs, expert_indices, num_experts)
            self.set_z_loss(inputs)
            self.pop_router_loss()

        if not self.training and not self.drop_tks and ep_group is not None:
            max_num = torch.max(torch.sum(cmask, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = max_num.item()

        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
        rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
        rank2 += torch.sum(mask1, dim=-2, keepdim=True)

        mask1 *= torch.lt(rank1, capacity)
        mask2 *= torch.lt(rank2, capacity)
        used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)

        rank1 = torch.sum(mask1 * rank1, dim=-1)
        rank2 = torch.sum(mask2 * rank2, dim=-1)

        if use_kernel:
            mask1 = torch.sum(mask1, dim=-1)
            mask2 = torch.sum(mask2, dim=-1)

            mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)

            return used_capacity, probs, mask, dest_idx, num_experts * capacity
        else:
            """
            The following code is equivalent to:

            ```
            weight1 = mask1 * probs.type_as(inputs)
            weight2 = mask2 * probs.type_as(inputs)
            rank1_sc = F.one_hot(rank1, num_classes=capacity)
            rank2_sc = F.one_hot(rank2, num_classes=capacity)

            cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
            cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
            cb_weight = cb_weight1 + cb_weight2
            sec_mask = cb_weight.bool()
            ```
            """

            weight1 = mask1 * probs.type_as(inputs)
            weight2 = mask2 * probs.type_as(inputs)

            cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
            sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
            indices = torch.arange(0, inputs.shape[0], device=inputs.device)
            cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
            cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
            sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
            sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]

            return used_capacity, cb_weight, sec_mask
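
The non-kernel branch of `Top2Router.forward` hands back `cb_weight` and `sec_mask` of shape (tokens, experts, capacity). The sketch below only illustrates how tensors of that shape are typically consumed by a token-dispatching MoE layer (dispatch into per-expert buffers, run the experts, combine back); it uses random stand-in tensors and is not ColossalAI's expert layer.

import torch

s, e, c, h = 6, 4, 3, 16                    # tokens, experts, capacity, hidden size
hidden_states = torch.randn(s, h)
sec_mask = torch.rand(s, e, c) > 0.7        # stand-in dispatch mask
cb_weight = torch.rand(s, e, c) * sec_mask  # stand-in combine weights

# Dispatch: scatter tokens into per-expert capacity slots -> [e, c, h].
dispatched = torch.einsum("sec,sh->ech", sec_mask.float(), hidden_states)

# ... the experts would transform `dispatched` here ...
expert_out = dispatched

# Combine: weighted sum of expert outputs back into token order -> [s, h].
output = torch.einsum("sec,ech->sh", cb_weight, expert_out)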

class TopKRouter(MoeRouter):
    """Masked matmul router using tokens choose top-k experts assignment.

    NOTE: this is modified from flaxformer.
    This router uses the same mechanism as in Switch Transformer
    (https://arxiv.org/abs/2101.03961) and V-MoE
    (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
    sorted by router_probs and then routed to their choice of expert until the
    expert's expert_capacity is reached. There is no guarantee that each token is
    processed by an expert, or that each expert receives at least one token.

    Attributes:
        num_selected_experts: Maximum number of experts to which each token is
            routed. Tokens may be routed to fewer experts if particular experts are
            oversubscribed / reach capacity.
    """

    def __init__(
        self,
        num_selected_experts: int,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
        )

    def forward(
        self,
        router_probs: torch.Tensor,
        expert_capacity: int,
    ) -> Tuple:
        """Computes masks for the top-k experts per token.

        Args:
            router_probs: <float32>[num_groups, tokens_per_group, num_experts]
                probabilities used to determine the routing of tokens to the experts.

        Returns:
            Dispatch and combine arrays for routing with masked matmuls.
        """
        # TODO: FIXME: add parallel group
        num_groups, _, num_experts = router_probs.shape

        # Top-k router probability and corresponding expert indices for each token.
        # Shape: [num_groups, tokens_per_group, num_selected_experts].
        expert_gate, expert_index = torch.topk(router_probs, self.k_value)

        self.set_aux_loss(router_probs, expert_index, num_experts)
        self.pop_router_loss()

        # Make num_selected_experts the leading axis to ensure that top-1 choices
        # have priority over top-2 choices, which have priority over top-3 choices,
        # etc.
        expert_index = torch.transpose(expert_index, 1, 2)
        # Shape: [num_groups, num_selected_experts * tokens_per_group]
        expert_index = expert_index.reshape(num_groups, -1)

        # Create mask out of indices.
        # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
        expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)

        # Experts have a fixed capacity that we cannot exceed. A token's priority
        # within the expert's buffer is given by the masked, cumulative capacity of
        # its target expert.
        # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
        token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
        # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
        token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
        # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
        token_priority = torch.transpose(token_priority, 1, 2)
        # For each token, across all selected experts, select the only non-negative
        # (unmasked) priority. Now, for group G routing to expert E, token T has
        # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
        # is its targeted expert.
        # Shape: [num_groups, tokens_per_group, num_experts].
        token_priority = torch.max(token_priority, dim=2)[0]

        # Token T can only be routed to expert E if its priority is positive and
        # less than the expert capacity. One-hot matrix will ignore indices outside
        # the range [0, expert_capacity).
        # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
        valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
        token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
        dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
        valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
        dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)

        # The combine array will be used for combining expert outputs, scaled by the
        # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
        # expert_capacity].
        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

        return combine_array, dispatch_mask

def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
    if not grouped:
        if top_k == 1:
            return Top1Router
        elif top_k == 2:
            return Top2Router
        else:
            raise NotImplementedError("top_k > 2 is not supported yet")
    else:
        return TopKRouter
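
A short usage sketch for the factory above. It assumes a working colossalai installation and that this module is importable as `colossalai.moe.routers`; the token and expert counts are arbitrary illustrative values.

from colossalai.moe.routers import Top2Router, get_router_cls

router_cls = get_router_cls(top_k=2)  # selects Top2Router for ungrouped top-2 routing
assert router_cls is Top2Router

router = router_cls()  # defaults: capacity_factor_train=1.25, min_capacity=4
# In training mode with 64 tokens and 8 experts: floor(2 * 1.25 * 64 / 8) = 20.
print(router.get_capacity(num_tokens=64, num_experts=8))  # 20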