
ShardFormer

🔗 Introduction

Shardformer is a module that automatically parallelizes mainstream models from libraries such as HuggingFace and TIMM. It aims to make parallelization hassle-free for users who do not have a systems background.

🔨 Usage

Quick Start

The sample API usage is given below. (If you enable flash attention, please install flash_attn. In addition, xformers' cutlass_op provides a supplementary optimization; it requires the sequence length to be a multiple of 8.)

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch colossalai
colossalai.launch_from_torch()

# create huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# shard the model with the default shard config
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).to('cuda')

# do everything like normal
...

Write your own policy

If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in API Design.

from colossalai.shardformer import Policy

class MyPolicy(Policy):
    # implement your own policy
    ...

# init model and shard former
...

# use customized policy to shard model
my_policy = MyPolicy()
shard_former.optimize(model, my_policy)



🗺 Roadmap

We will follow this roadmap to develop Shardformer:

  • API Design
  • API Implementation
  • Unit Testing
  • Policy Implementation
    • Hugging Face
      • NLP
        • BERT
        • T5
        • LlaMa
        • GPT2
        • OPT
        • BLOOM
        • GLM
        • RoBERTa
        • ALBERT
        • ERNIE
        • GPT Neo
        • GPT-J
      • CV
        • ViT
        • BEiT
        • SwinTransformer
        • SwinTransformer V2
      • Audio
        • Whisper
      • Multi-modal
        • SAM
        • BLIP-2
  • Flash Attention Support
    • NLP
      • BERT
      • T5
      • LlaMa
      • GPT2
      • OPT
      • BLOOM
      • GLM
      • RoBERTa
      • ALBERT
      • ERNIE
      • GPT Neo
      • GPT-J

💡 API Design

We discuss the major components of ShardFormer below to help you better understand how things work. This section serves as the design doc for Shardformer, so the function signatures may differ from the actual implementation. Please refer to the code for more details.


Distributed Modules

ShardFormer replaces the original PyTorch module with a distributed module. The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new forward function to execute distributed computation. Each distributed module implements its from_native_module static method to convert the PyTorch module to its corresponding distributed module.

class ParallelModule(torch.nn.Module, ABC):

    @staticmethod
    @abstractmethod
    def from_native_module(module: torch.nn.Module, process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> "ParallelModule":
        """
        Convert a native module to a parallelized module.

        Examples:

        ```python
        # replace module
        my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        ```
        """

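The conversion idea can be illustrated without PyTorch. The following is a minimal sketch, not the library's implementation: `ToyLinear` and `ToyColumnParallelLinear` are hypothetical stand-ins, the weight is a plain nested list, and an integer rank and world size replace the process group. It assumes the layer is sharded along its output features, so that each rank computes a slice of the output.

```python
from typing import List

Matrix = List[List[float]]


class ToyLinear:
    """Hypothetical stand-in for torch.nn.Linear: y = W x, W is (out_features, in_features)."""

    def __init__(self, weight: Matrix):
        self.weight = weight

    def forward(self, x: List[float]) -> List[float]:
        return [sum(w * v for w, v in zip(row, x)) for row in self.weight]


class ToyColumnParallelLinear(ToyLinear):
    """Holds only the slice of output features that belongs to one rank."""

    @staticmethod
    def from_native_module(module: ToyLinear, rank: int, world_size: int) -> "ToyColumnParallelLinear":
        out_features = len(module.weight)
        assert out_features % world_size == 0, "output features must divide evenly"
        chunk = out_features // world_size
        # Keep rows [rank * chunk, (rank + 1) * chunk) of the original weight.
        return ToyColumnParallelLinear(module.weight[rank * chunk:(rank + 1) * chunk])


native = ToyLinear([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]])
shards = [ToyColumnParallelLinear.from_native_module(native, rank, 2) for rank in range(2)]

x = [3.0, 4.0]
# Concatenating the per-rank outputs (an all-gather in a real setup) recovers the full output.
gathered = [y for shard in shards for y in shard.forward(x)]
```

Each shard keeps the same `forward` interface as the original module, which is what lets a replacement be swapped in without changing the calling code.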
Shard Config

ShardConfig is a simple data class to tell ShardFormer how sharding will be performed.

@dataclass
class ShardConfig:
    tensor_parallel_process_group: ProcessGroup = None
    enable_fused_normalization: bool = False
    ...

    # Some possible future config fields
    tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
    inference_only: bool # only inject inference-suitable sharding policy
    use_flash_attention: bool # whether to use flash attention to speed up attention

Policy

The Policy class describes how to handle the model sharding. It is merely a description; the actual sharding is performed by ModelSharder. We abstract the policy into three stages:

  1. Preprocessing: call Policy.preprocess to do some prior work before sharding, for example, resizing the embedding.
  2. Providing ModulePolicyDescription: call Policy.module_policy to get a set of ModulePolicyDescription objects that tell ModelSharder how a submodule's attributes, child parameters, and deeper submodules will be substituted.
  3. Postprocessing: call Policy.postprocess to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.

@dataclass
class ModulePolicyDescription:
    r"""
    Describe how the attributes and parameters will be transformed in a policy.

    Args:
        attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
        param_replacement (List[Callable]): a list of functions to perform in-place param replacement. Each function must receive exactly one argument: module.
        sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
                    object which specifies the module to be replaced and the target module used to replace it.
        method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
    """
    attribute_replacement: Dict[str, Any] = None
    param_replacement: List[Callable] = None
    sub_module_replacement: List[SubModuleReplacementDescription] = None
    method_replacement: Dict[str, Callable] = None

@dataclass
class SubModuleReplacementDescription:
    r"""
    Describe how a submodule will be replaced

    Args:
        suffix (str): used to get the submodule object
        target_module (ParallelModule): specifies the module class used to replace the submodule
        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
        ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
    """
    suffix: str
    target_module: ParallelModule
    kwargs: Dict[str, Any] = None
    ignore_if_not_exist: bool = False


class Policy(ABC):

    def __init__(self):
        self.model = None

    def set_model(self, model: nn.Module) -> None:
        """
        Set model as an attribute of the Policy object so that we can access the model's attributes.
        """
        self.model = model

    @abstractmethod
    def preprocess(self) -> nn.Module:
        """
        Perform some preprocessing on the model, such as resizing the embedding size
        """
        ...

    @abstractmethod
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """
        Return the dict describing the modification policy. The key is the original layer class and the value is
        the ModulePolicyDescription that specifies how to modify that layer.
        """
        ...

    @abstractmethod
    def postprocess(self) -> nn.Module:
        """
        Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head
        """
        ...

Model Sharder

ModelSharder is the class in charge of sharding the model based on the given policy.

class ModelSharder:

    def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, policy: Policy = None):
        #TODO: input is a cls or a obj
        ...

    def shard(self) -> None:
        """
        Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
        """
        ...

    def replace_module(self) -> None:
        """
        Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively.
        """
        ...
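
How a sharder might consume these descriptions can be sketched with plain Python objects. This is an illustrative toy under stated assumptions, not the real ModelSharder: `Node`, `ToySubModuleReplacement`, `ToyModulePolicy`, `replace_submodule` and `shard` are all hypothetical names, only attribute and submodule replacement are modeled, and the preprocessing and postprocessing stages are left out.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


class Node:
    """Hypothetical stand-in for nn.Module: attributes, some of which are child Nodes."""

    def __init__(self, **attrs: Any):
        self.__dict__.update(attrs)


class Linear(Node): ...
class ShardedLinear(Node): ...
class Attention(Node): ...


@dataclass
class ToySubModuleReplacement:
    suffix: str                    # dotted path relative to the matched module
    convert: Callable[[Any], Any]  # plays the role of target_module.from_native_module


@dataclass
class ToyModulePolicy:
    attribute_replacement: Dict[str, Any] = field(default_factory=dict)
    sub_module_replacement: List[ToySubModuleReplacement] = field(default_factory=list)


def replace_submodule(root: Any, suffix: str, convert: Callable[[Any], Any]) -> None:
    """Walk the dotted suffix and swap the final attribute for its converted version."""
    *parents, leaf = suffix.split(".")
    owner = root
    for name in parents:
        owner = getattr(owner, name)
    setattr(owner, leaf, convert(getattr(owner, leaf)))


def shard(model: Any, policy: Dict[type, ToyModulePolicy]) -> None:
    """Apply the description registered for each module's class, then recurse into children."""
    description = policy.get(type(model))
    if description is not None:
        for name, value in description.attribute_replacement.items():
            setattr(model, name, value)
        for item in description.sub_module_replacement:
            replace_submodule(model, item.suffix, item.convert)
    for child in list(vars(model).values()):
        if isinstance(child, Node):
            shard(child, policy)


model = Node(attn=Attention(num_heads=8, proj=Node(query=Linear(out_features=512))))
policy = {
    Attention: ToyModulePolicy(
        attribute_replacement={"num_heads": 4},  # 8 heads split across 2 ranks
        sub_module_replacement=[
            ToySubModuleReplacement("proj.query", lambda m: ShardedLinear(out_features=m.out_features // 2)),
        ],
    )
}
shard(model, policy)
```

Keying the policy on the module class means one description covers every instance of that class in the model, and the suffix only needs to be valid relative to the matched module.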

User-facing API

We only expose a limited number of APIs to the user to keep their user experience simple and clean.

class ShardFormer:
    """
    Parallelize model based on the given config and policy

    Example:

    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    model = shard_former.optimize(model, policy=policy)
    dataloader = shard_former.shard_dataset(dataset)

    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
        2. serve as a store for shard config
        """
        self.shard_config = shard_config
        self.pg_manager = None

    def init_distributed(self) -> colossalai.cluster.ProcessGroupManager:
        """
        Initialize the distributed process group according to the
        """
        pg_manager = ...
        self.pg_manager = pg_manager
        return pg_manager

    def shard_model(self, model: torch.nn.Module, policy: Policy) -> torch.nn.Module:
        """
        Shard model for TP and PP
        """
        ...

    def shard_dataset(self, dataset: Dataset) -> DataLoader:
        """
        Shard dataset for DP
        """
        ...

⌨️ Development Notes

Add New Policy to Shardformer

This section serves as the guideline for writing new policies and registering them in Shardformer.

  • Step 1. Write your own model policy

You can create a new file in the colossalai/shardformer/policies folder, name it after the model, and implement your policy there. Do not import any model zoo library at the top of the file, because we do not want to import the library when the policy is not used. Libraries such as transformers should be imported only inside the function body where they are needed.

Please follow the following protocols when writing your policy:

  • You have to decide clearly what exactly you want to replace in the original PyTorch module

    • Use ModulePolicyDescription.attribute_replacement to replace the module attributes
    • Use ModulePolicyDescription.param_replacement to replace the module parameters
    • Use ModulePolicyDescription.sub_module_replacement to replace the submodules completely. The target module should implement the from_native_module for the replacement.
    • Use ModulePolicyDescription.method_replacement to replace the module methods. These replacement methods should be put in the shardformer/modeling/<model-name>.py.
  • You can implement the ParallelModule for primitive modules in the shardformer/layer/<model-name>.py file. Primitive modules are modules that are not composed of other modules. For example, torch.nn.Linear is a primitive module, while a module such as BertEncoder in the transformers library is a composite module. Primitive modules do not contain nested nn.Module members. For composite modules, you should consider using ModulePolicyDescription to implement your replacement.

  • ParallelModule is meant to be used in two ways: ParallelModule.from_native_module converts a native PyTorch module to the ParallelModule, and ParallelModule(...) instantiates the module directly, just like a normal PyTorch module. ParallelModule should only be implemented for modules whose weights are sharded. If you want to make your module compatible with ModulePolicyDescription.sub_module_replacement and there is no weight sharding in your module, you can just implement the from_native_module method without inheriting from ParallelModule, as in colossalai/shardformer/layer/normalization.py.

  • Do not import any file in colossalai/shardformer/policies or colossalai/shardformer/modeling, to avoid unwanted import errors. For example, if a file in these folders accidentally imports the transformers library at the top of the file, the user will have to install transformers even if they do not use this file. Any file in the modeling folder should only be imported by the policy file. A policy implementation should only be imported dynamically via the autopolicy or manually via the ShardFormer module.

  • Try to keep your import statement on third-party libraries such as transformers within the function body instead of the header section of the file. This is because we do not want to import the library when we do not use the policy.

  • Step 2. Register your policy to the autopolicy

Next, you need to register your policy in the colossalai/shardformer/policies/autopolicy.py file.

For example, to register the policy for the BERT model, we just add a key-value pair to the _POLICY_LIST dictionary. The key is the qualname of the model object (you can get it by model.__class__.__qualname__). The value is a PolicyLocation object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may import libraries (such as transformers) that we do not want to load when the policy is not used.

_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}
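
The lazy lookup described above can be sketched as follows. This is an assumption-laden illustration rather than the contents of autopolicy.py: `ToyPolicyLocation`, `_TOY_POLICY_LIST`, `full_name` and `resolve_policy` are hypothetical names, standard-library modules stand in for the model class and the policy file so that the sketch runs anywhere, and the key is assumed to be the defining module path joined with the class qualname, as the BERT entry above suggests.

```python
import importlib
from collections import OrderedDict
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyPolicyLocation:
    file_name: str   # module to import lazily
    class_name: str  # attribute to fetch from that module


# Hypothetical registry: OrderedDict plays the "model", json.decoder plays the "policy file".
_TOY_POLICY_LIST = {
    "collections.OrderedDict": ToyPolicyLocation(file_name="json.decoder", class_name="JSONDecoder"),
}


def full_name(obj: object) -> str:
    """Build the registry key from the defining module and the class qualname."""
    cls = obj.__class__
    return f"{cls.__module__}.{cls.__qualname__}"


def resolve_policy(model: object) -> type:
    """Look up the model's key and import the policy module only at this point."""
    key = full_name(model)
    location = _TOY_POLICY_LIST.get(key)
    if location is None:
        raise NotImplementedError(f"no policy registered for {key}")
    module = importlib.import_module(location.file_name)
    return getattr(module, location.class_name)


policy_cls = resolve_policy(OrderedDict())
```

Because the registry stores only strings, building it never triggers an import; the cost of loading a policy file and its dependencies is paid only for the model that is actually sharded.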

Write Your Unit Tests

This section serves as the guideline for testing the shardformer module.

  • Step 1. Add your model to the model zoo in the test kits.

Add your model to the tests/kit/model_zoo folder. This allows you to define test-related components for this model. You can take tests/kit/model_zoo/transformers/llama.py as an example for reference.

  • Step 2. Write your unit test for the model

Next, implement your unit test in the tests/test_shardformer folder. Please refer to other similar tests for style consistency.

  • Step 3. Execute your test

When you run tests locally, you should run tests for both your newly-added test file and the whole shardformer module tests.

# test for your own test file
pytest tests/test_shardformer/test_model/<your-file>.py

# test for the whole shardformer module
pytest tests/test_shardformer

📊 Benchmarking

System Performance

We conducted benchmark tests to evaluate the performance improvement of Shardformer. We compared the training time of the original model with that of the sharded model.

We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

In the case of using 2 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| ----- | --------- | ----------- |
| 256   | 11.2 ms   | 17.2 ms     |
| 512   | 9.8 ms    | 19.5 ms     |
| 1024  | 19.6 ms   | 18.9 ms     |
| 2048  | 46.6 ms   | 30.8 ms     |
| 4096  | 160.5 ms  | 90.4 ms     |


In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| ----- | --------- | ----------- |
| 256   | 10.0 ms   | 21.1 ms     |
| 512   | 11.5 ms   | 20.2 ms     |
| 1024  | 22.1 ms   | 20.6 ms     |
| 2048  | 46.9 ms   | 24.8 ms     |
| 4096  | 160.4 ms  | 68.0 ms     |


As shown in the tables above, the sharded model is slower than the original at sequence lengths of 256 and 512, where the sharding overhead dominates. From a sequence length of 1024 upward the sharded model is faster, and the advantage grows with the sequence length.
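
The speedup at each sequence length can be worked out directly from the reported times. The short script below is only arithmetic on the numbers in the two tables above; the `speedup` helper and the dictionary layout are illustrative choices, not part of Shardformer.

```python
# Training times in milliseconds, copied from the tables above:
# number of GPUs -> N_CTX -> (org_model, shard_model).
times_ms = {
    2: {256: (11.2, 17.2), 512: (9.8, 19.5), 1024: (19.6, 18.9), 2048: (46.6, 30.8), 4096: (160.5, 90.4)},
    4: {256: (10.0, 21.1), 512: (11.5, 20.2), 1024: (22.1, 20.6), 2048: (46.9, 24.8), 4096: (160.4, 68.0)},
}


def speedup(org_ms: float, shard_ms: float) -> float:
    """Ratio above 1 means the sharded model is faster than the original."""
    return org_ms / shard_ms


speedups = {
    gpus: {n_ctx: speedup(org, shard) for n_ctx, (org, shard) in rows.items()}
    for gpus, rows in times_ms.items()
}

for gpus, rows in speedups.items():
    for n_ctx, ratio in rows.items():
        print(f"{gpus} GPUs, N_CTX={n_ctx}: {ratio:.2f}x")
```

At N_CTX = 4096 this gives roughly 1.78x on 2 GPUs and 2.36x on 4 GPUs, while at 256 and 512 the ratio is below 1 in both configurations.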

Convergence

To validate that training a model with Shardformer does not impact its convergence, we fine-tuned the BERT model both with and without Shardformer and compared the accuracy, F1 score, and loss of the resulting models.

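| accuracy | f1      | loss    | GPU number | model shard |
| -------- | ------- | ------- | ---------- | ----------- |
| 0.82594  | 0.87441 | 0.09913 | 4          | True        |
| 0.81884  | 0.87299 | 0.10120 | 2          | True        |
| 0.81855  | 0.87124 | 0.10357 | 1          | False       |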

Overall, the results indicate that using Shardformer during model training does not affect convergence: the sharded runs match or slightly exceed the unsharded baseline on all three metrics.