ShardFormer
Introduction
Shardformer is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. It aims to make parallelization hassle-free for users who do not have a systems background.
Usage
Quick Start
A sample of the API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, the `cutlass_op` from xformers provides a supplementary optimization):
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch colossalai
colossalai.launch_from_torch(config={})

# create huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# describe how the model should be sharded
# (tp_group and stage_manager must have been created beforehand)
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           pipeline_stage_manager=stage_manager,
                           enable_tensor_parallelism=True,
                           enable_fused_normalization=True,
                           enable_flash_attention=True,
                           enable_jit_fused=True,
                           enable_sequence_parallelism=True,
                           enable_sequence_overlap=True)
shard_former = ShardFormer(shard_config=shard_config)

# optimize returns a (model, shared_params) tuple, so only the model is moved to the GPU
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.to('cuda')

# do everything like normal
...
The arguments of `ShardConfig` are described below:
- `tensor_parallel_process_group`: The process group for tensor parallelism. It is required when using tensor parallelism. Defaults to None, which is the global process group.
- `pipeline_stage_manager`: The pipeline stage manager used for inter-process communication in pipeline parallelism. It is required when using pipeline parallelism. Defaults to None, which means pipeline parallelism is not used.
- `enable_tensor_parallelism`: Whether to use tensor parallelism. Defaults to True.
- `enable_fused_normalization`: Whether to use fused layernorm. Defaults to False.
- `enable_flash_attention`: Whether to switch on flash attention. Defaults to False.
- `enable_jit_fused`: Whether to switch on JIT fused operators. Defaults to False.
- `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False.
- `enable_all_optimization`: Whether to turn on all optimization tools, including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
- `extra_kwargs`: A dict to store extra kwargs for ShardFormer.
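The dependencies between these flags can be sketched in plain Python. The class below is a hypothetical stand-in (`ToyShardConfig` is not part of ColossalAI) that models only the two rules stated in the list: `enable_all_optimization` switches on every listed optimization, and `enable_sequence_overlap` requires `enable_sequence_parallelism`. How the real `ShardConfig` enforces these rules may differ.

```python
from dataclasses import dataclass


@dataclass
class ToyShardConfig:
    """Hypothetical stand-in for ShardConfig; only the flag logic is modelled."""

    enable_tensor_parallelism: bool = True
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False
    enable_all_optimization: bool = False

    def __post_init__(self) -> None:
        if self.enable_all_optimization:
            # Switch on every optimization named in the argument description.
            self.enable_fused_normalization = True
            self.enable_flash_attention = True
            self.enable_jit_fused = True
            self.enable_sequence_parallelism = True
            self.enable_sequence_overlap = True
        if self.enable_sequence_overlap and not self.enable_sequence_parallelism:
            # Overlap hides sequence-parallel communication, so it is
            # meaningless without sequence parallelism.
            raise ValueError(
                "enable_sequence_overlap requires enable_sequence_parallelism=True"
            )


# Defaults: tensor parallelism on, every optimization off.
default_cfg = ToyShardConfig()
# One switch to enable all optimizations.
full_cfg = ToyShardConfig(enable_all_optimization=True)
```

Validating in `__post_init__` means an inconsistent combination fails at construction time rather than somewhere inside the sharding step.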
Write your own policy
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in API Design.
from colossalai.shardformer import Policy

class MyPolicy(Policy):
    # implement your own policy
    ...

# init model and shard former
...

# use customized policy to shard model
my_policy = MyPolicy()
shard_former.optimize(model, my_policy)
Roadmap
We will follow this roadmap to develop Shardformer:
- API Design
- API Implementation
- Unit Testing
- Policy Implementation
model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
---|---|---|---|---|---|---|---|---|---|
bert | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] |
t5 | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [ ] | [ ] |
llama V1/V2 | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [ ] | [ ] |
gpt2 | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] |
opt | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [ ] | [ ] |
bloom | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] |
chatglm2 | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] | [â] |
vit | [â] | [â] | [ ] | [â] | [â] | [â] | [â] | [ ] | [ ] |
whisper | [â] | [â] | [â] | [â] | [â] | [ ] | [â] | [ ] | [ ] |
sam | [â] | [ ] | [ ] | [â] | [â] | [â] | [â] | [ ] | [ ] |
blip2 | [â] | [ ] | [ ] | [â] | [â] | [â] | [â] | [ ] | [ ] |
falcon | [â] | [â] | [â] | [â] | [â] | [ ] | [â] | [ ] | [ ] |
roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
API Design
We will discuss the major components of `ShardFormer` below to help you better understand how things work.
This section serves as the design doc for Shardformer, and the function signatures might differ from the actual implementation.
Please refer to the code for more details.
Distributed Modules
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
class ParallelModule(torch.nn.Module):

    @staticmethod
    @abstractmethod
    def from_native_module(module: torch.nn.Module, process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> "ParallelModule":
        """
        Convert a native module to a parallelized module.

        Examples:

        ```python
        # replace module
        my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        ```
        """
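To illustrate the conversion pattern, here is a minimal sketch in plain Python (no PyTorch, no real process group). `NativeLinear` and `ToyColumnParallelLinear` are hypothetical names, and passing `rank` and `world_size` directly is an assumption standing in for the `process_group` argument. Each rank keeps only its slice of the output features, and concatenating the per-rank outputs recovers the original result.

```python
from typing import List

Matrix = List[List[float]]


def matvec(weight: Matrix, x: List[float]) -> List[float]:
    """Multiply a [out_features][in_features] weight by an input vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


class NativeLinear:
    """Toy stand-in for an unsharded linear layer (no bias)."""

    def __init__(self, weight: Matrix) -> None:
        self.weight = weight

    def forward(self, x: List[float]) -> List[float]:
        return matvec(self.weight, x)


class ToyColumnParallelLinear:
    """Holds only this rank's share of the output features."""

    def __init__(self, weight_shard: Matrix) -> None:
        self.weight = weight_shard

    @staticmethod
    def from_native_module(
        module: NativeLinear, rank: int, world_size: int
    ) -> "ToyColumnParallelLinear":
        out_features = len(module.weight)
        if out_features % world_size != 0:
            raise ValueError("out_features must be divisible by world_size")
        chunk = out_features // world_size
        # Keep the contiguous slice of output rows owned by this rank.
        return ToyColumnParallelLinear(module.weight[rank * chunk:(rank + 1) * chunk])

    def forward(self, x: List[float]) -> List[float]:
        return matvec(self.weight, x)


native = NativeLinear([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]])
shards = [ToyColumnParallelLinear.from_native_module(native, r, 2) for r in range(2)]
x = [3.0, 4.0]
# Simulate the gather step by concatenating each rank's partial output.
gathered = [value for shard in shards for value in shard.forward(x)]
```

In a real distributed run each rank would build only its own shard, and the concatenation would be a collective gather rather than a local loop.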
Shard Config
`ShardConfig` is a simple data class to tell `ShardFormer` how sharding will be performed.
@dataclass
class ShardConfig:
    tensor_parallel_process_group: ProcessGroup = None
    enable_fused_normalization: bool = False
    ...

    # Some possible future config fields
    tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d']  # support different tensor parallel modes
    use_flash_attention: bool  # whether to use flash attention to speed up attention
    extra_kwargs: Dict[str, Any]  # extra kwargs for the shardformer
Policy
The `Policy` class describes how to handle the model sharding.
It is merely a description; the actual sharding is performed by `ModelSharder`.
We abstract the policy into three stages:
- Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding.
- Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a collection of `ModulePolicyDescription` objects that tell `ModelSharder` how each submodule's attributes, child parameters, and deeper submodules will be substituted.
- Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
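The order of these stages can be sketched with toy objects. Everything below is hypothetical (`ToyModel`, `ToyPolicy`, and `run_stages` are illustration names, not ColossalAI classes): preprocessing pads the vocabulary to a multiple of the parallel size, `module_policy` returns a description of what to substitute, and postprocessing ties the classifier head to the embedding.

```python
class ToyModel:
    """Toy model with an embedding table and an (initially untied) head."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size
        self.embedding = [[0.0] for _ in range(vocab_size)]
        self.head = None


class ToyPolicy:
    """Hypothetical policy following the preprocess / describe / postprocess flow."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self.model = None

    def set_model(self, model: ToyModel) -> None:
        self.model = model

    def preprocess(self) -> ToyModel:
        # Pad the vocabulary so it divides evenly across ranks.
        padding = -self.model.vocab_size % self.world_size
        self.model.embedding.extend([[0.0] for _ in range(padding)])
        self.model.vocab_size += padding
        return self.model

    def module_policy(self) -> dict:
        # Only a description; nothing is modified here.
        return {"embedding": {"shard_dim": 0, "num_shards": self.world_size}}

    def postprocess(self) -> ToyModel:
        # Bind the classifier head to the embedding weights.
        self.model.head = self.model.embedding
        return self.model


def run_stages(model: ToyModel, policy: ToyPolicy) -> dict:
    """Run the stages in order and return the substitution plan."""
    policy.set_model(model)
    policy.preprocess()
    plan = policy.module_policy()
    # A real sharder would apply `plan` to the submodules at this point.
    policy.postprocess()
    return plan


model = ToyModel(vocab_size=10)
plan = run_stages(model, ToyPolicy(world_size=4))
```

Keeping `module_policy` free of side effects is what lets the policy stay a pure description while a separate component performs the substitution.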
@dataclass
class ModulePolicyDescription:
    r"""
    Describe how the attributes and parameters will be transformed in a policy.

    Args:
        attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
        param_replacement (List[Callable]): a list of functions to perform in-place param replacement. Each function must receive exactly one argument: module.
        sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
                    object which specifies the module to be replaced and the target module used for the replacement.
        method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
    """
    attribute_replacement: Dict[str, Any] = None
    param_replacement: List[Callable] = None
    # quoted because SubModuleReplacementDescription is defined below
    sub_module_replacement: List["SubModuleReplacementDescription"] = None
    method_replacement: Dict[str, Callable] = None

@dataclass
class SubModuleReplacementDescription:
    r"""
    Describe how a submodule will be replaced

    Args:
        suffix (str): used to get the submodule object
        target_module (ParallelModule): specifies the module class used to replace the submodule
        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
        ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
    """
    suffix: str
    target_module: ParallelModule
    kwargs: Dict[str, Any] = None
    ignore_if_not_exist: bool = False
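How a `suffix` and `ignore_if_not_exist` might be interpreted can be sketched as follows. This is a hypothetical helper (`replace_submodule` is not a ColossalAI function) that works on plain Python objects: it walks the dotted suffix from a parent object, swaps the final attribute for a converted one, and either skips or raises when the path does not exist.

```python
from types import SimpleNamespace
from typing import Any, Callable


def replace_submodule(
    root: Any,
    suffix: str,
    convert: Callable[[Any], Any],
    ignore_if_not_exist: bool = False,
) -> bool:
    """Replace the attribute at the dotted `suffix` below `root`.

    Returns True if a replacement happened, False if it was skipped.
    """
    *parent_names, leaf = suffix.split(".")
    parent = root
    for name in parent_names:
        parent = getattr(parent, name, None)
        if parent is None:
            break
    if parent is None or not hasattr(parent, leaf):
        if ignore_if_not_exist:
            return False
        raise AttributeError(f"submodule {suffix!r} not found")
    # Convert the existing submodule and put the result in its place.
    setattr(parent, leaf, convert(getattr(parent, leaf)))
    return True


# Toy "module tree": a block holding an attention submodule with a query layer.
block = SimpleNamespace(attention=SimpleNamespace(query="dense"))
replaced = replace_submodule(block, "attention.query", lambda m: f"parallel({m})")
skipped = replace_submodule(
    block, "attention.key", lambda m: m, ignore_if_not_exist=True
)
```

The `convert` callable plays the role that `target_module.from_native_module` plays in the description above.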
class Policy(ABC):
    r"""
    The base class for all the policies. For each different model, it should have a different policy class,
    like BertPolicy for Bert Model or OPTPolicy for OPT model.

    Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
    If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
    """

    def __init__(self):
        self.model = None

    def set_model(self, model: nn.Module) -> None:
        """
        Set model as an attribute of the Policy object so that we can access the model's attributes.
        """
        self.model = model

    def set_shard_config(self, shard_config: ShardConfig) -> None:
        r"""
        Set shard config as an attribute of the Policy object.

        Args:
            shard_config (:class:`ShardConfig`): The shard config to be applied
        """
        self.shard_config = shard_config
        # validate the config (config_sanity_check is not shown in this design sketch)
        self.config_sanity_check()

    @abstractmethod
    def preprocess(self) -> nn.Module:
        """
        Perform some preprocessing on the model, such as resizing the embedding size
        """
        ...

    @abstractmethod
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """
        Return the dict describing the modification policy: the key is the original layer class and the value is the
        description of how that layer is modified
        """
        ...

    @abstractmethod
    def postprocess(self) -> nn.Module:
        """
        Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head
        """
        ...
Model Sharder
`ModelSharder` is the class in charge of sharding the model based on the given policy.
class ModelSharder:

    def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, policy: Policy = None):
        # TODO: decide whether policy is passed as a class or as an object
        ...

    def shard(self) -> List[Dict[int, Tensor]]:
        """
        Shard the model with the help of pre-processing, replace_model_class, replace_module, and post-processing.
        Returns the shared parameters.
        """
        ...

    def replace_module(self) -> None:
        """
        Replace the layers according to the policy. Call Policy.module_policy() to get the module descriptions, then call _replace_module recursively.
        """
        ...
User-facing API
We only expose a limited number of APIs to the user to keep their user experience simple and clean.
class ShardFormer:
    """
    Parallelize model based on the given config and policy

    Example:

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig()
    shard_former = ShardFormer(shard_config=shard_config)
    model, shared_params = shard_former.optimize(org_model)
    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a distributed coordinator
        2. Serve as a store for the shard config
        """
        self.shard_config = shard_config
        self.coordinator = DistCoordinator()

    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
        r"""
        This method will optimize the model based on the given policy.

        Args:
            model (`torch.nn.Module`): the original huggingface model
            policy (`Policy`): the custom policy for sharding; the built-in policy is used if None

        Returns: the sharded model and the shared parameters
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
        shared_params = sharder.shard()
        return model, shared_params
Development Notes
Add New Policy to Shardformer
This section serves as the guideline for writing new policies and registering them into `shardformer`.
- Step 1. Write your own model policy
You can create a new file in the `colossalai/shardformer/policies` folder and name the file after the model. You can implement your policy in this file. You should not import any model zoo library at the header section of the file because we do not want to import the library when the policy is not used. Libraries such as `transformers` should be imported only in the function body when needed.
Please follow these protocols when writing your policy:
- You have to make a clear decision about what exactly you want to replace in the original PyTorch module:
  - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes.
  - Use `ModulePolicyDescription.param_replacement` to replace the module parameters.
  - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement `from_native_module` for the replacement.
  - Use `ModulePolicyDescription.method_replacement` to replace the module methods. These replacement methods should be put in `shardformer/modeling/<model-name>.py`.
- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/<model-name>.py` file. Primitive modules are modules that are not composed of other modules, i.e. they have no nested inner `nn.Module` members. For example, `torch.nn.Linear` is a primitive module, while `BertEncoder` in the `transformers` library is a composite module. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement.
- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert a native PyTorch module to the `ParallelModule`, and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should only be implemented for modules whose weights are sharded. If you want to make your module compatible with `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting from `ParallelModule`, as in `colossalai/shardformer/layer/normalization.py`.
- Do not import any file in `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` to avoid unwanted import errors. For example, if a file in these folders accidentally imports the `transformers` library at the top of the file, the user will have to install `transformers` even if they do not use this file. Any file in the `modeling` folder should only be imported by the policy file. A policy implementation should only be imported dynamically via the autopolicy or manually via the `ShardFormer` module.
- Try to keep your import statements for third-party libraries such as `transformers` within the function body instead of the header section of the file. This is because we do not want to import the library when the policy is not used.
- Step 2. Register your policy to the autopolicy
Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.
For example, to register the policy for the BERT model, we just add a key-value pair to the `_POLICY_LIST` dictionary. The key is the qualified name of the model class; as the example below shows, it includes the module path in front of the name returned by `model.__class__.__qualname__`. The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when the policy is not used.
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}
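The idea of registering a location instead of importing the policy can be sketched with the standard library only. Everything here is an illustration under stated assumptions: the `PolicyLocation` dataclass is a stand-in with the two fields named above, `full_name` and `load_policy_class` are hypothetical helpers, and standard-library classes (`fractions.Fraction`, `decimal.Decimal`) play the roles of the model and policy so the example runs without `transformers`.

```python
import importlib
from dataclasses import dataclass
from fractions import Fraction


@dataclass
class PolicyLocation:
    """Stand-in: where to find a policy class, without importing it yet."""

    file_name: str
    class_name: str


def full_name(obj: object) -> str:
    """Module path plus qualname of the object's class, used as registry key."""
    cls = obj.__class__
    return f"{cls.__module__}.{cls.__qualname__}"


# Toy registry: the "model" is fractions.Fraction and its "policy" is
# decimal.Decimal. Both are placeholders chosen only because they exist
# in every Python installation.
_TOY_POLICY_LIST = {
    "fractions.Fraction": PolicyLocation(file_name="decimal", class_name="Decimal"),
}


def load_policy_class(model: object) -> type:
    """Look up the location by model name and import the module lazily."""
    key = full_name(model)
    if key not in _TOY_POLICY_LIST:
        raise NotImplementedError(f"no policy registered for {key}")
    location = _TOY_POLICY_LIST[key]
    # The import happens only now, when the policy is actually requested.
    module = importlib.import_module(location.file_name)
    return getattr(module, location.class_name)


toy_model = Fraction(1, 2)
policy_cls = load_policy_class(toy_model)
```

Because the registry stores only strings, building it never triggers an import of the libraries that a policy file depends on.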
Write Your Unit Testing
This section serves as the guideline for testing the `shardformer` module.
- Step 1. Add your model to the model zoo in the test kits.
Add your model to the `tests/kit/model_zoo` folder. This allows you to define test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as an example for reference.
- Step 2. Write your unit testing for the model
Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
- Step 3. Execute your test
When you run tests locally, you should run both your newly added test file and the whole `shardformer` module test suite.
# test for your own test file
pytest tests/test_shardformer/test_model/<your-file>.py
# test for the whole shardformer module
pytest tests/test_shardformer
Benchmarking
System Performance
We conducted benchmark tests to evaluate the performance improvement of Shardformer. We compared the training time of the original model with that of the sharded model.
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
In the case of using 2 GPUs, the training times are as follows.
N_CTX | org_model | shard_model |
---|---|---|
256 | 11.2ms | 17.2ms |
512 | 9.8ms | 19.5ms |
1024 | 19.6ms | 18.9ms |
2048 | 46.6ms | 30.8ms |
4096 | 160.5ms | 90.4ms |
In the case of using 4 GPUs, the training times are as follows.
N_CTX | org_model | shard_model |
---|---|---|
256 | 10.0ms | 21.1ms |
512 | 11.5ms | 20.2ms |
1024 | 22.1ms | 20.6ms |
2048 | 46.9ms | 24.8ms |
4096 | 160.4ms | 68.0ms |
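As shown in the tables above, the sharded model is slower than the original at sequence lengths of 256 and 512, and becomes faster once the sequence length reaches 1024. The advantage grows with sequence length: at N_CTX = 4096 the sharded model is roughly 1.8x faster on 2 GPUs and 2.4x faster on 4 GPUs.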
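The speedups implied by the two tables can be recomputed with a few lines of Python. The numbers are copied from the tables above; the helper names are only for illustration.

```python
# Training time in ms per N_CTX: (org_model, shard_model), keyed by GPU count.
timings_ms = {
    2: {256: (11.2, 17.2), 512: (9.8, 19.5), 1024: (19.6, 18.9),
        2048: (46.6, 30.8), 4096: (160.5, 90.4)},
    4: {256: (10.0, 21.1), 512: (11.5, 20.2), 1024: (22.1, 20.6),
        2048: (46.9, 24.8), 4096: (160.4, 68.0)},
}


def speedups(table):
    """Speedup of the sharded model over the original (values > 1 mean faster)."""
    return {n_ctx: org / shard for n_ctx, (org, shard) in table.items()}


def first_winning_length(table):
    """Smallest N_CTX at which the sharded model is faster than the original."""
    return min(n_ctx for n_ctx, ratio in speedups(table).items() if ratio > 1.0)


for num_gpus, table in timings_ms.items():
    ratios = ", ".join(f"{n}: {r:.2f}x" for n, r in speedups(table).items())
    print(f"{num_gpus} GPUs -> {ratios}")
    print(f"  sharded model wins from N_CTX = {first_winning_length(table)}")
```

For both GPU counts the crossover is at N_CTX = 1024, which matches the observation that the benefit appears at around 1000 tokens.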
Convergence
To validate that training a model with Shardformer does not impact its convergence, we fine-tuned the BERT model with and without Shardformer. The Shardformer runs use it together with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results.
The configurations are as follows:
batch_size = 2
epoch = 3
lr = 2.4e-5
accumulation_steps = 8
warmup_fraction = 0.03
accuracy | f1 | loss | GPU number | model sharded |
---|---|---|---|---|
0.82971 | 0.87713 | 0.23194 | 4 | True |
0.83797 | 0.88006 | 0.22683 | 2 | True |
0.84521 | 0.88700 | 0.21822 | 1 | False |
Overall, the sharded runs stay within about 0.02 of the unsharded baseline on accuracy, F1 score, and loss, which indicates that using Shardformer during model training does not materially affect convergence.