🔌 Extensions
📚 Introduction
This module is designed to offer extensions to the existing ColossalAI framework. It is a collection of high-performance kernels that speed up training and inference. Unlike writing an individual kernel, the extensions module offers a layer of abstraction that organizes kernels written for different compiler backends and different hardware backends. Please see the design and usage in the sections below.
🪅 Design
The extensions module is a sub-module of the colossalai.kernel module. This module is put at the project root directory so that it can be imported for AOT (ahead-of-time) build. At the same time, it is symbolically linked at the colossalai.kernel.extensions path for runtime build.
As we want to support multi-backend kernels, we have to consider multiple compiler options such as torch.jit, CUDA, and triton, and multiple hardware backends such as CPU, GPU, and NPU. To make this easy for users, we have abstracted the kernels into extensions and expose a single loader to the user for each kind of kernel.
For example, a user who wants the CPU Adam kernel can just call load() on the kernel loader. The kernel loader will automatically select the correct extension based on the current hardware and compiler backend, so the user does not need to worry about the details of the kernel implementation. On an ARM CPU, the ARM kernel will be built and loaded; on an x86 CPU, the x86 kernel will be loaded.
from colossalai.kernel.kernel_loader import CPUAdamLoader
# load the kernel compatible with the current hardware
kernel = CPUAdamLoader().load()
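
The hardware-based selection described above could look roughly like the sketch below. This is a minimal illustration, not ColossalAI's actual implementation: the helper names `normalize_arch` and `pick_cpu_adam_extension` are hypothetical, and the extension name "cpu_adam_x86" is an assumption (only "cpu_adam_arm" appears in this README).

```python
import platform
from typing import Optional

# Hypothetical mapping from a normalized CPU family to an extension name.
# "cpu_adam_arm" appears in this README; "cpu_adam_x86" is assumed here.
_ARCH_TO_EXTENSION = {
    "arm": "cpu_adam_arm",
    "x86": "cpu_adam_x86",
}


def normalize_arch(machine: str) -> str:
    """Collapse the spellings reported by platform.machine() into 'arm' or 'x86'."""
    name = machine.lower()
    if name in ("aarch64", "arm64") or name.startswith("arm"):
        return "arm"
    if name in ("x86_64", "amd64", "i386", "i686", "x86"):
        return "x86"
    raise RuntimeError(f"unsupported CPU architecture: {machine!r}")


def pick_cpu_adam_extension(machine: Optional[str] = None) -> str:
    """Return the extension name for the given (or the current) CPU architecture."""
    if machine is None:
        machine = platform.machine()
    return _ARCH_TO_EXTENSION[normalize_arch(machine)]


print(pick_cpu_adam_extension("aarch64"))  # cpu_adam_arm
print(pick_cpu_adam_extension("AMD64"))    # cpu_adam_x86
```

Calling `pick_cpu_adam_extension()` with no argument inspects the machine it runs on, which is the behavior a loader would want by default.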
🛠 API Usage
To make colossalai.kernel easy to use, we expose some simple APIs that you can pick from based on your scenario.
- Case 1: Simply load a kernel
from colossalai.kernel.kernel_loader import CPUAdamLoader
# load the kernel compatible with the current hardware
kernel = CPUAdamLoader().load()
- Case 2: Load a specific kernel
This case applies if you are familiar with the extensions available.
from colossalai.kernel.kernel_loader import CPUAdamLoader
# load the kernel by giving the kernel name
kernel = CPUAdamLoader().load(ext_name="cpu_adam_arm")
- Case 3: Register your own extension
This case applies if you know how to write an extension. If you do not know how, you can refer to the section below.
from colossalai.kernel.kernel_loader import CPUAdamLoader
from colossalai.kernel.base_extension import _Extension

# create your own extension class
class MyExtension(_Extension):

    def __init__(self):
        self._name = "my_extension"
        self._support_aot = True
        self._support_jit = True
        self.priority = 10

    # implementation here
    ...

# register your extension
# you can use the priority value to make sure your kernel will be loaded by default
CPUAdamLoader.register_extension(MyExtension)

# load the kernel
kernel = CPUAdamLoader().load()
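
To make the role of the priority value concrete, here is a self-contained toy registry. It is a sketch under stated assumptions, not the real kernel loader: `ToyExtension` and `ToyLoader` are hypothetical stand-ins, and the rule "highest priority among available extensions wins" is an assumption about how such a loader might behave.

```python
from typing import List, Optional, Type


class ToyExtension:
    """Hypothetical stand-in for an extension: a name, a priority, an availability check."""

    name = "default"
    priority = 1

    def is_available(self) -> bool:
        return True

    def load(self) -> str:
        return f"kernel from {self.name}"


class UnavailableExtension(ToyExtension):
    """High priority, but its (pretend) hardware is missing, so it must be skipped."""

    name = "unavailable"
    priority = 100

    def is_available(self) -> bool:
        return False


class MyExtension(ToyExtension):
    name = "my_extension"
    priority = 10


class ToyLoader:
    # Class-level registry shared by every loader instance.
    registry: List[Type[ToyExtension]] = [ToyExtension, UnavailableExtension]

    @classmethod
    def register_extension(cls, extension: Type[ToyExtension]) -> None:
        cls.registry.append(extension)

    def load(self, ext_name: Optional[str] = None) -> str:
        candidates = [ext() for ext in self.registry]
        if ext_name is not None:
            # An explicit name bypasses the priority ranking.
            for ext in candidates:
                if ext.name == ext_name:
                    return ext.load()
            raise ValueError(f"unknown extension: {ext_name}")
        usable = [ext for ext in candidates if ext.is_available()]
        if not usable:
            raise RuntimeError("no available extension")
        return max(usable, key=lambda ext: ext.priority).load()


ToyLoader.register_extension(MyExtension)
print(ToyLoader().load())                    # kernel from my_extension
print(ToyLoader().load(ext_name="default"))  # kernel from default
```

In this sketch the unavailable extension loses despite its higher priority, which is why registering a custom extension with a priority above the built-in ones is enough to make it the default on supported hardware.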
🏗 Write a customized extension
It is easy to write a customized extension. If you have experience writing CUDA/triton kernels, you should get familiar with the process quickly.
You just need to inherit from the _Extension base class or from a backend-specific class such as _CudaExtension and implement the abstract methods. Then register your extension with the kernel loader as shown in Case 3 above. The kernel loader will automatically select the correct extension based on the priority score, the current hardware, and the compiler backend.
from typing import Callable, Union

from colossalai.kernel.base_extension import _Extension

class MyExtension(_Extension):

    def __init__(self):
        self._name = "my_extension"
        self._support_aot = True
        self._support_jit = True
        self.priority = 10

    def is_available(self) -> bool:
        """
        Return if the required hardware can be found.
        """
        ...

    def assert_compatible(self) -> None:
        """
        Check if the hardware required by the kernel is compatible.
        """
        ...

    def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
        """
        If this kernel can be built AOT, it should return an extension object
        to Python setuptools for compilation.
        """
        ...

    def build_jit(self) -> Callable:
        """
        Build extension kernel just in time.
        """
        ...

    def load(self):
        """
        The API called by the user to get the kernel.
        """
        ...
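
One plausible way for load() to tie these methods together is to check compatibility, try to import a prebuilt (AOT) module, and fall back to a JIT build if that import fails. The sketch below illustrates that flow in pure Python. It is an assumption about the control flow, not the module's confirmed behavior, and `ToyExtension`, its `prebuilt_module` argument, and the doubling "kernel" are all hypothetical.

```python
import importlib
from typing import Any, Callable


class ToyExtension:
    """Hypothetical extension whose load() prefers a prebuilt module over a JIT build."""

    def __init__(self, name: str, prebuilt_module: str):
        self.name = name
        # Import path where an AOT-built module would live (assumed for this sketch).
        self.prebuilt_module = prebuilt_module

    def assert_compatible(self) -> None:
        # A real extension would check hardware and toolchain versions here.
        pass

    def build_jit(self) -> Callable[[int], int]:
        # Stand-in for compiling a kernel at runtime: return a plain Python function.
        def kernel(x: int) -> int:
            return x * 2

        return kernel

    def load(self) -> Any:
        self.assert_compatible()
        try:
            # AOT path: the kernel was compiled at install time and is importable.
            return importlib.import_module(self.prebuilt_module)
        except ImportError:
            # JIT path: nothing prebuilt was found, so build the kernel now.
            return self.build_jit()


# No prebuilt module with this name exists, so the JIT path is taken.
kernel = ToyExtension("toy", "no_such_prebuilt_module_xyz").load()
print(kernel(21))  # 42
```

Trying the prebuilt module first keeps startup fast when the package was installed with AOT kernels, while the fallback keeps the same call working on machines where only a runtime build is possible.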
✏️ Acknowledgement
This module is written from scratch, but we learnt a lot by looking into DeepSpeed's op_builder. We wish to acknowledge their great work and contributions to the open-source community.