ColossalAI/colossalai/shardformer
flybird11111 576a2f7b10
[gemini] gemini support tensor parallelism. (#4942)
2023-11-10 10:15:16 +08:00
..
examples [misc] update pre-commit and run all files (#4752) 2023-09-19 14:20:26 +08:00
layer [gemini] gemini support tensor parallelism. (#4942) 2023-11-10 10:15:16 +08:00
modeling [gemini] gemini support tensor parallelism. (#4942) 2023-11-10 10:15:16 +08:00
policies [gemini] gemini support tensor parallelism. (#4942) 2023-11-10 10:15:16 +08:00
shard [Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014) 2023-11-07 15:01:50 +08:00
README.md [hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926) 2023-11-03 13:32:43 +08:00
__init__.py [shardformer] Refactor shardformer api (#4001) 2023-07-04 16:05:01 +08:00
_utils.py [misc] update pre-commit and run all files (#4752) 2023-09-19 14:20:26 +08:00

README.md

ShardFormer

🔗 Introduction

Shardformer is a module that automatically parallelizes mainstream models from libraries such as HuggingFace and TIMM. It aims to make parallelization hassle-free for users who do not have a systems background.

🔨 Usage

Quick Start

Sample API usage is given below. If you enable flash attention, please install flash_attn. In addition, the cutlass_op from xformers provides a supplementary optimization.

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch colossalai
colossalai.launch_from_torch(config={})

# create huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# create the shard config
# NOTE: tp_group (a process group) and stage_manager (a pipeline stage manager)
# are assumed to have been created beforehand
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                        pipeline_stage_manager=stage_manager,
                        enable_tensor_parallelism=True,
                        enable_fused_normalization=True,
                        enable_flash_attention=True,
                        enable_jit_fused=True,
                        enable_sequence_parallelism=True,
                        enable_sequence_overlap=True)

shard_former = ShardFormer(shard_config=shard_config)
# optimize returns a (model, shared_params) tuple, so move only the model to the GPU
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.to('cuda')

# do everything like normal
...

The arguments of ShardConfig are described below:

  • tensor_parallel_process_group: The process group used for tensor parallelism. It is required when using tensor parallelism. Defaults to None, which means the global process group.

  • pipeline_stage_manager: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.

  • enable_tensor_parallelism: Whether to use tensor parallelism. Defaults to True.

  • enable_fused_normalization: Whether to use fused layernorm. Defaults to False.

  • enable_flash_attention: Whether to switch on flash attention. Defaults to False.

  • enable_jit_fused: Whether to switch on JIT fused operators. Defaults to False.

  • enable_sequence_parallelism: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.

  • enable_sequence_overlap: Whether to turn on sequence overlap, which overlaps computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.

  • enable_all_optimization: Whether to turn on all optimization tools, including fused normalization, flash attention, JIT fused operators, sequence parallelism and sequence overlap. Defaults to False.

  • inference_only: Whether to run forward passes only. Defaults to False.
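The way these flags depend on each other can be sketched with a toy dataclass. This is a minimal illustration and not the library's ShardConfig: the class name `ToyShardConfig` and its validation logic are assumptions derived only from the argument descriptions above.

```python
from dataclasses import dataclass


@dataclass
class ToyShardConfig:
    """Hypothetical stand-in for ShardConfig; only the flag logic is modeled."""

    enable_tensor_parallelism: bool = True
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False
    enable_all_optimization: bool = False

    def __post_init__(self) -> None:
        # enable_all_optimization switches on every optimization it is documented to cover.
        if self.enable_all_optimization:
            self.enable_fused_normalization = True
            self.enable_flash_attention = True
            self.enable_jit_fused = True
            self.enable_sequence_parallelism = True
            self.enable_sequence_overlap = True
        # Sequence overlap is only valid together with sequence parallelism.
        if self.enable_sequence_overlap and not self.enable_sequence_parallelism:
            raise ValueError(
                "enable_sequence_overlap requires enable_sequence_parallelism=True"
            )


if __name__ == "__main__":
    print(ToyShardConfig(enable_all_optimization=True))
```

Rejecting the invalid combination at construction time is one possible design; the real class may handle it differently.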

Write your own policy

If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in API Design.

from colossalai.shardformer import Policy

class MyPolicy(Policy):
    # implement your own policy
    ...

# init model and shard former
...

# use customized policy to shard model
my_policy = MyPolicy()
shard_former.optimize(model, my_policy)



🗺 Roadmap

We will follow this roadmap to develop Shardformer:

  • API Design
  • API Implementation
  • Unit Testing
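  • Policy Implementation

| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |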

💡 API Design

We will discuss the major components of ShardFormer below to help you better understand how things work. This section serves as the design doc for Shardformer and the function signature might differ from the actual implementation. Please refer to the code for more details.


Distributed Modules

ShardFormer replaces the original PyTorch module with a distributed module. The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new forward function to execute distributed computation. Each distributed module implements its from_native_module static method to convert the PyTorch module to its corresponding distributed module.

from abc import ABC, abstractmethod
from typing import Tuple, Union

import torch
from torch.distributed import ProcessGroup

class ParallelModule(torch.nn.Module, ABC):

    @staticmethod
    @abstractmethod
    def from_native_module(module: torch.nn.Module, process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Examples:

        ```python
        # replace module
        my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        ```
        """
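To make the idea of a distributed module concrete, the following framework-free sketch shows the arithmetic behind a column-parallel linear layer: the weight is split along its output features, each rank computes a slice of the output, and the slices are concatenated. It is only an illustration under simplifying assumptions. The helper names (`linear`, `shard_output_features`, `column_parallel_linear`) are hypothetical, ranks are simulated with a loop, and the real `Linear1D_Col` works on PyTorch tensors and process groups.

```python
from typing import List

Matrix = List[List[float]]


def linear(x: List[float], weight: Matrix) -> List[float]:
    """Compute y = W x, where weight has shape (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def shard_output_features(weight: Matrix, world_size: int) -> List[Matrix]:
    """Split the weight evenly along the output-feature dimension, one shard per rank."""
    out_features = len(weight)
    if out_features % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    chunk = out_features // world_size
    return [weight[rank * chunk:(rank + 1) * chunk] for rank in range(world_size)]


def column_parallel_linear(x: List[float], weight: Matrix, world_size: int) -> List[float]:
    """Simulate every rank computing its output slice, then gather the slices."""
    partial_outputs = [linear(x, shard) for shard in shard_output_features(weight, world_size)]
    return [value for part in partial_outputs for value in part]


if __name__ == "__main__":
    weight = [[1, 0], [0, 1], [2, 3], [4, 5]]
    x = [1, 2]
    print(column_parallel_linear(x, weight, world_size=2))
```

Because each rank holds only `out_features / world_size` rows of the weight, parameter memory per rank shrinks accordingly, at the cost of a gather step on the outputs.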

Shard Config

ShardConfig is a simple data class to tell ShardFormer how sharding will be performed.

@dataclass
class ShardConfig:
    tensor_parallel_process_group: ProcessGroup = None
    enable_fused_normalization: bool = False
    ...

    # Some possible future config fields (the defaults shown are placeholders)
    tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] = '1d' # support different tensor parallel modes
    inference_only: bool = False # only inject inference-suitable sharding policy
    use_flash_attention: bool = False # whether to use flash attention to speed up attention

Policy

The Policy class describes how to handle the model sharding. It is merely a description; the actual sharding is performed by ModelSharder. We abstract the policy into three stages:

  1. Preprocessing: call Policy.preprocess to do some preparatory work before sharding, for example, resizing the embedding.
  2. Providing ModulePolicyDescription: call Policy.module_policy to get a set of ModulePolicyDescription objects that tell ModelSharder how each submodule's attributes, child parameters, and deeper submodules will be substituted.
  3. Postprocessing: call Policy.postprocess to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.

@dataclass
class ModulePolicyDescription:
    r"""
    Describe how the attributes and parameters will be transformed in a policy.

    Args:
        attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
        param_replacement (List[Callable]): a list of functions to perform in-place param replacement. Each function must receive exactly one argument: module.
        sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
                    object which specifies the module to be replaced and the target module used for the replacement.
        method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
    """
    attribute_replacement: Dict[str, Any] = None
    param_replacement: List[Callable] = None
    sub_module_replacement: List[SubModuleReplacementDescription] = None
    method_replacement: Dict[str, Callable] = None

@dataclass
class SubModuleReplacementDescription:
    r"""
    Describe how a submodule will be replaced

    Args:
        suffix (str): used to get the submodule object
        target_module (ParallelModule): specifies the module class used to replace the submodule
        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
        ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
    """
    suffix: str
    target_module: ParallelModule
    kwargs: Dict[str, Any] = None
    ignore_if_not_exist: bool = False


class Policy(ABC):
    r"""
    The base class for all the policies. For each different model, it should have a different policy class,
    like BertPolicy for Bert Model or OPTPolicy for OPT model.

    Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
    built-in policies by setting `policy = None`, which is already the default argument for `ShardFormer.optimize`.
    If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
    """

    def __init__(self):
        self.model = None

    def set_model(self, model: nn.Module) -> None:
        """
        Set model as an attribute of the Policy object so that we can access the model's attributes.
        """
        self.model = model

    def set_shard_config(self, shard_config: ShardConfig) -> None:
        r"""
        Set shard config as an attribute of the Policy object.
        Args:
            shard_config (:class:`ShardConfig`): The shard config to be applied
        """
        self.shard_config = shard_config

        self.config_sanity_check()

    @abstractmethod
    def preprocess(self) -> nn.Module:
        """
        Perform some preprocessing on the model, such as resizing the embedding size
        """
        ...

    @abstractmethod
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """
        Return the dict describing the module replacements. The key is the original layer class and the value is the
        ModulePolicyDescription that specifies how that layer is modified
        """
        ...

    @abstractmethod
    def postprocess(self) -> nn.Module:
        """
        Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head
        """
        ...
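The three stages can be walked through with a framework-free toy. This is a sketch under loud assumptions: the model is a plain dictionary rather than an `nn.Module`, and `ToyPolicy`, `PadVocabPolicy`, and `apply_policy` are hypothetical names that only mirror the preprocess, module_policy, and postprocess flow described above.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ToyPolicy(ABC):
    """Hypothetical, framework-free analogue of the Policy interface."""

    def __init__(self) -> None:
        self.model: Dict[str, Any] = {}

    def set_model(self, model: Dict[str, Any]) -> None:
        self.model = model

    @abstractmethod
    def preprocess(self) -> Dict[str, Any]: ...

    @abstractmethod
    def module_policy(self) -> Dict[str, Dict[str, Any]]: ...

    @abstractmethod
    def postprocess(self) -> Dict[str, Any]: ...


class PadVocabPolicy(ToyPolicy):
    def __init__(self, tp_size: int) -> None:
        super().__init__()
        self.tp_size = tp_size

    def preprocess(self) -> Dict[str, Any]:
        # Resize the embedding: round the vocabulary up to a multiple of tp_size.
        remainder = self.model["vocab_size"] % self.tp_size
        if remainder:
            self.model["vocab_size"] += self.tp_size - remainder
        return self.model

    def module_policy(self) -> Dict[str, Dict[str, Any]]:
        # Attribute replacement: each rank keeps num_heads / tp_size attention heads.
        heads = self.model["attention"]["num_heads"]
        return {"attention": {"num_heads": heads // self.tp_size}}

    def postprocess(self) -> Dict[str, Any]:
        # Bind the classifier head to the embedding weights.
        self.model["lm_head"] = self.model["embedding"]
        return self.model


def apply_policy(model: Dict[str, Any], policy: ToyPolicy) -> Dict[str, Any]:
    """Run the stages in order, as a sharder would."""
    policy.set_model(model)
    policy.preprocess()
    for layer_name, attributes in policy.module_policy().items():
        model[layer_name].update(attributes)
    return policy.postprocess()


if __name__ == "__main__":
    toy_model = {"vocab_size": 30522, "embedding": [0.0], "lm_head": None,
                 "attention": {"num_heads": 12}}
    print(apply_policy(toy_model, PadVocabPolicy(tp_size=4)))
```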

Model Sharder

ModelSharder is the class in charge of sharding the model based on the given policy.

class ModelSharder:

    def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, policy: Policy = None):
        #TODO: input is a cls or an obj
        ...

    def shard(self) -> List[Dict[int, Tensor]]:
        """
        Shard the model through pre-processing, replace_model_class, replace_module, and post-processing, and return the shared parameters.
        """
        ...

    def replace_module(self) -> None:
        """
        Replace the layers according to the policy. Call Policy.module_policy() to get the module policy descriptions, then call _replace_module recursively.
        """
        ...
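The replacement step resolves each `suffix` relative to a matched layer and swaps the submodule it names. The sketch below illustrates that lookup with plain Python objects. It is an assumption-laden toy, not the ModelSharder implementation: `Node` stands in for `torch.nn.Module`, `replace_by_suffix` is a hypothetical helper, and `convert` plays the role of `from_native_module`.

```python
from typing import Any, Callable


class Node:
    """Minimal stand-in for torch.nn.Module: children are plain attributes."""

    def __init__(self, **children: Any) -> None:
        for name, child in children.items():
            setattr(self, name, child)


def replace_by_suffix(root: Any, suffix: str, convert: Callable[[Any], Any],
                      ignore_if_not_exist: bool = False) -> bool:
    """Replace the submodule at a dotted path such as 'attention.query'.

    Returns True if a replacement happened, False if the path was missing
    and ignore_if_not_exist is set; raises AttributeError otherwise.
    """
    *parent_path, leaf = suffix.split(".")
    parent = root
    for name in parent_path:
        parent = getattr(parent, name, None)
        if parent is None:
            break
    if parent is None or not hasattr(parent, leaf):
        if ignore_if_not_exist:
            return False
        raise AttributeError(f"submodule '{suffix}' does not exist")
    # Swap the native submodule for its converted (for example, sharded) counterpart.
    setattr(parent, leaf, convert(getattr(parent, leaf)))
    return True


if __name__ == "__main__":
    layer = Node(attention=Node(query="linear", key="linear"))
    replace_by_suffix(layer, "attention.query", lambda m: f"sharded_{m}")
    print(layer.attention.query, layer.attention.key)
```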

User-facing API

We only expose a limited number of APIs to the user to keep their user experience simple and clean.

class ShardFormer:
    """
    Parallelize model based on the given config and policy

    Example:

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig()
    shard_former = ShardFormer(shard_config=shard_config)
    model, shared_params = shard_former.optimize(org_model)

    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a distributed coordinator
        2. serve as a store for shard config
        """
        self.shard_config = shard_config
        self.coordinator = DistCoordinator()

    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
        r"""
        This method will optimize the model based on the given policy.

        Args:
            model (`torch.nn.Module`): the original huggingface model
            policy (`Policy`): the custom policy for sharding

        Returns: the sharded model and the shared parameters
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
        shared_params = sharder.shard()
        return model, shared_params

⌨️ Development Notes

Add New Policy to Shardformer

This section serves as a guideline for writing new policies and registering them in Shardformer.

  • Step 1. Write your own model policy

You can create a new file in the colossalai/shardformer/policies folder, name it after the model, and implement your policy in it. Do not import any model zoo library in the header section of the file, because we do not want to import the library when the policy is not used. Libraries such as transformers should be imported only inside the function body, when needed.

Please follow these rules when writing your policy:

  • Decide clearly what exactly you want to replace in the original PyTorch module

    • Use ModulePolicyDescription.attribute_replacement to replace the module attributes
    • Use ModulePolicyDescription.param_replacement to replace the module parameters
    • Use ModulePolicyDescription.sub_module_replacement to replace the submodules completely. The target module should implement the from_native_module for the replacement.
    • Use ModulePolicyDescription.method_replacement to replace the module methods. These replacement methods should be put in the shardformer/modeling/<model-name>.py.
  • You can implement the ParallelModule for primitive modules in the shardformer/layer/<model-name>.py file. Primitive modules are modules that are not composed of other modules. For example, torch.nn.Linear is a primitive module, while BertEncoder in the transformers library is a composite module. Primitive modules do not contain nested nn.Module members. For composite modules, consider using ModulePolicyDescription to implement your replacement.

  • ParallelModule is meant to be used in two ways: ParallelModule.from_native_module to convert native PyTorch module to the ParallelModule and ParallelModule(...) to instantiate the module directly just like a normal PyTorch module. ParallelModule should be only implemented for modules whose weights are sharded. If you want to make your module compatible with the ModulePolicyDescription.sub_module_replacement and there is no weight sharding in your module, you can just implement the from_native_module method without inheriting the ParallelModule like colossalai/shardformer/layer/normalization.py.

  • Do not import any file from the colossalai/shardformer/policies and colossalai/shardformer/modeling folders, to avoid unwanted import errors. For example, if a file in these folders imports the transformers library at the top of the file, the user will have to install transformers even if they never use this file. Any file in the modeling folder should be imported only by the policy file. A policy implementation should be imported only dynamically via the autopolicy or manually via the ShardFormer module.

  • Try to keep your import statement on third-party libraries such as transformers within the function body instead of the header section of the file. This is because we do not want to import the library when we do not use the policy.

  • Step 2. Register your policy to the autopolicy

Next, you need to register your policy in the colossalai/shardformer/policies/autopolicy.py file.

For example, to register the policy for the BERT model, we just add a key-value pair to the _POLICY_LIST dictionary. The key is the fully qualified name of the model class, that is, its module path followed by its qualname (you can get the latter with model.__class__.__qualname__). The value is a PolicyLocation object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may import libraries (such as transformers) which we do not want to load when the policy is not used.

_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}
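The reason for storing a file name and class name instead of the class itself is lazy importing. The sketch below shows one way such a lookup can work. It is not the autopolicy implementation: `ToyPolicyLocation`, `_TOY_POLICY_LIST`, `full_name`, and `load_policy_class` are hypothetical, and standard-library classes stand in for models and policies so that the example runs without any third-party package.

```python
import importlib
from dataclasses import dataclass
from fractions import Fraction


@dataclass
class ToyPolicyLocation:
    """Hypothetical analogue of PolicyLocation."""
    file_name: str
    class_name: str


# Stand-ins: fractions.Fraction plays the "model", json.JSONDecoder the "policy".
_TOY_POLICY_LIST = {
    "fractions.Fraction": ToyPolicyLocation(file_name="json", class_name="JSONDecoder"),
}


def full_name(obj: object) -> str:
    """Module path plus qualname of the object's class."""
    cls = obj.__class__
    return f"{cls.__module__}.{cls.__qualname__}"


def load_policy_class(obj: object) -> type:
    """Import the policy module only when a matching model is actually seen."""
    key = full_name(obj)
    if key not in _TOY_POLICY_LIST:
        raise NotImplementedError(f"no policy registered for {key}")
    location = _TOY_POLICY_LIST[key]
    module = importlib.import_module(location.file_name)
    return getattr(module, location.class_name)


if __name__ == "__main__":
    print(full_name(Fraction(1, 2)))
    print(load_policy_class(Fraction(1, 2)))
```

With this pattern, registering a policy costs nothing at import time; the module named in `file_name` is loaded only on the first lookup that needs it.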

Write Your Unit Testing

This section serves as the guideline for testing the shardformer module.

  • Step 1. Add your model to the model zoo in the test kits.

Add your model to the tests/kit/model_zoo folder. This allows you to define test-related components for this model. You can take tests/kit/model_zoo/transformers/llama.py as a reference.

  • Step 2. Write the unit test for your model

Next, implement your unit test in the tests/test_shardformer folder. Please refer to other similar tests for style consistency.

  • Step 3. Execute your test

When you run tests locally, you should run tests for both your newly-added test file and the whole shardformer module tests.

# test for your own test file
pytest tests/test_shardformer/test_model/<your-file>.py

# test for the whole shardformer module
pytest tests/test_shardformer

📊 Benchmarking

System Performance

We conducted benchmark tests to evaluate the performance improvement of Shardformer. We compared the training time of the original model with that of the sharded model.

We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

In the case of using 2 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| :---: | :---: | :---: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |


In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| :---: | :---: | :---: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |


As shown in the tables above, the sharded model is slower for short sequences but becomes faster once the sequence length reaches about 1024, and the benefit of Shardformer's parallel optimization grows with longer sequences.
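The crossover point and the speedups can be read off the measurements with a few lines of arithmetic. The snippet below only restates the numbers from the two tables; the helper names `speedup` and `first_faster_length` are introduced here for illustration.

```python
from typing import Dict, Optional, Tuple

# (org_model ms, shard_model ms) per sequence length, keyed by GPU count.
TIMES_MS: Dict[int, Dict[int, Tuple[float, float]]] = {
    2: {256: (11.2, 17.2), 512: (9.8, 19.5), 1024: (19.6, 18.9),
        2048: (46.6, 30.8), 4096: (160.5, 90.4)},
    4: {256: (10.0, 21.1), 512: (11.5, 20.2), 1024: (22.1, 20.6),
        2048: (46.9, 24.8), 4096: (160.4, 68.0)},
}


def speedup(org_ms: float, shard_ms: float) -> float:
    """Values above 1.0 mean the sharded model is faster."""
    return org_ms / shard_ms


def first_faster_length(table: Dict[int, Tuple[float, float]]) -> Optional[int]:
    """Smallest N_CTX at which the sharded model beats the original."""
    for n_ctx in sorted(table):
        org_ms, shard_ms = table[n_ctx]
        if shard_ms < org_ms:
            return n_ctx
    return None


if __name__ == "__main__":
    for gpus, table in TIMES_MS.items():
        print(f"{gpus} GPUs: sharded model is faster from N_CTX = {first_faster_length(table)}")
        for n_ctx in sorted(table):
            print(f"  N_CTX={n_ctx}: speedup {speedup(*table[n_ctx]):.2f}x")
```

At N_CTX = 4096 this gives a speedup of roughly 1.78x on 2 GPUs and 2.36x on 4 GPUs.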

Convergence

To validate that training a model with Shardformer does not impact its convergence, we fine-tuned the BERT model both with and without Shardformer. The Shardformer run combines Shardformer with pipeline parallelism and data parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results.

The configurations are as follows:

batch_size = 2
epoch = 3
lr = 2.4e-5
accumulation_steps = 8
warmup_fraction = 0.03

| accuracy | f1 | loss | GPU number | model sharded |
| :---: | :---: | :---: | :---: | :---: |
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |

Overall, the results demonstrate that using Shardformer during model training does not affect convergence.