# ⚡️ ShardFormer
## 📚 Table of Contents
- [⚡️ ShardFormer](#️-shardformer)
  - [📚 Table of Contents](#-table-of-contents)
  - [🔗 Introduction](#-introduction)
  - [🔨 Usage](#-usage)
    - [Quick Start](#quick-start)
    - [Write your own policy](#write-your-own-policy)
  - [🗺 Roadmap](#-roadmap)
  - [💡 API Design](#-api-design)
    - [Distributed Modules](#distributed-modules)
    - [Shard Config](#shard-config)
    - [Policy](#policy)
    - [Model Sharder](#model-sharder)
    - [User-facing API](#user-facing-api)
## 🔗 Introduction
**Shardformer** is a module that automatically parallelizes mainstream models from libraries such as Hugging Face and TIMM. It aims to make parallelization hassle-free for users who do not have a systems background.
## 🔨 Usage
### Quick Start
The sample API usage is given below:
```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch colossalai
colossalai.launch_from_torch()

# create the Hugging Face model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# create the shard config and shard the model
shard_config = ShardConfig(tensor_parallel_size=2,
                           data_parallel_size=1,
                           gather_output=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model).to('cuda')

# do everything like normal
...
```
### Write your own policy

If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).

```python
from colossalai.shardformer import Policy

class MyPolicy(Policy):
    # implement your own policy
    ...

# init model and shard former
...

# use customized policy to shard model
my_policy = MyPolicy()
shard_former.shard_model(model, my_policy)
```

## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:
- [x] API Design
- [x] API Implementation
- [x] Unit Testing
- [ ] Policy Implementation
  - [ ] Hugging Face
    - [ ] NLP
      - [x] BERT
      - [x] T5
      - [x] LLaMA
      - [ ] GPT2
      - [ ] BLOOM
      - [ ] RoBERTa
      - [ ] ALBERT
      - [ ] ERNIE
      - [ ] GPT Neo
      - [ ] GPT-J
    - [ ] CV
      - [ ] ViT
      - [ ] BEiT
      - [ ] SwinTransformer
      - [ ] SwinTransformer V2
    - [ ] Audio
      - [ ] To be added
    - [ ] Multi-modal
      - [ ] To be added
## 💡 API Design
We discuss the major components of `ShardFormer` below to help you better understand how things work.
This section serves as the design doc for Shardformer, so the function signatures shown here might differ from the actual implementation.
Please refer to the code for more details.
<p align="center">
   <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/shardformer_flowchart.png" width="600" />
   <br/>
   <b>This diagram is deprecated, need to update it</b>
</p>
### Distributed Modules
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
```python
from abc import ABC, abstractmethod
from typing import Tuple, Union

import torch
from torch.distributed import ProcessGroup


class ParallelModule(torch.nn.Module, ABC):

    @staticmethod
    @abstractmethod
    def from_native_module(module: torch.nn.Module,
                           process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Examples:

            # replace module
            my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        """
```
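
The idea behind a column-parallel layer such as `Linear1D_Col` can be illustrated without any distributed setup. The sketch below is a toy in plain Python, not ShardFormer code: `linear`, `shard_rows`, `sharded_linear`, and the simulated ranks are hypothetical names introduced only for illustration. It assumes the weight is split along the output dimension, that each rank computes its own slice of the output, and that the slices are concatenated in rank order, which is roughly what gathering the output would amount to.

```python
from typing import List

Matrix = List[List[int]]


def linear(weight: Matrix, x: List[int]) -> List[int]:
    """Compute y = W x for a weight of shape (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def shard_rows(weight: Matrix, rank: int, world_size: int) -> Matrix:
    """Return the slice of output rows owned by `rank` (hypothetical helper)."""
    assert len(weight) % world_size == 0, "out_features must divide evenly"
    chunk = len(weight) // world_size
    return weight[rank * chunk:(rank + 1) * chunk]


def sharded_linear(weight: Matrix, x: List[int], world_size: int) -> List[int]:
    """Simulate every rank computing its partial output, then gather them."""
    gathered: List[int] = []
    for rank in range(world_size):
        local_weight = shard_rows(weight, rank, world_size)
        gathered.extend(linear(local_weight, x))  # concatenate in rank order
    return gathered


weight = [[1, 2], [3, 4], [5, 6], [7, 8]]  # 4 output features, 2 input features
x = [1, 10]
full_output = linear(weight, x)
gathered_output = sharded_linear(weight, x, world_size=2)
```

Here `full_output` and `gathered_output` are equal, which is the property a distributed module has to preserve when it replaces the native one.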

### Shard Config

`ShardConfig` is a simple data class to tell `ShardFormer` how sharding will be performed.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class ShardConfig:
    data_parallel_size: int
    tensor_parallel_size: int
    ...

    # Some possible future config fields
    pipeline_parallel_size: int    # support pipeline parallelism
    tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']    # support different tensor parallel modes
    inference_only: bool    # only inject inference-suitable sharding policy
    gather_output: bool    # gather the model output
    use_flash_attention: bool    # whether to use flash attention to speed up attention
```

### Policy

The `Policy` class describes how to handle the model sharding.
It is merely a description; the actual sharding is performed by `ModelSharder`.
We abstract the policy into four stages:

1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding.
2. Providing a new class: call `Policy.new_model_class` to get a new class for the model. This class replaces attributes and the forward function.
3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a set of `ModulePolicyDescription` objects that tell `ModelSharder` how each submodule's attributes, child parameters, and deeper submodules will be substituted.
4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Type, Union

import torch.nn as nn

# ParallelModule is the base class described in "Distributed Modules"


@dataclass
class ModulePolicyDescription:
    """
    Describe how the attributes and parameters will be transformed in a policy

    Args:
        attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
        param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is

                    def example_replace_weight(module: torch.nn.Module, process_group):
                        weight = module.weight
                        new_weight = shard_rowwise(weight, process_group)
                        module.weight = torch.nn.Parameter(new_weight)

        sub_module_replacement: each element in the list is a SubModuleReplacementDescription object which specifies the module to be replaced and the target module used for the replacement
    """
    attribute_replacement: Dict[str, Any]
    param_replacement: List[Callable]
    sub_module_replacement: List["SubModuleReplacementDescription"]


@dataclass
class SubModuleReplacementDescription:
    """
    Describe how a submodule will be replaced

    Args:
        suffix (str): used to get the submodule object
        target_module (ParallelModule): specifies the module class used to replace the submodule
        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
    """
    suffix: str
    target_module: "ParallelModule"
    kwargs: Dict[str, Any] = None


class Policy(ABC):

    def __init__(self):
        self.model = None

    def set_model(self, model: nn.Module) -> None:
        """
        Set model as an attribute of the Policy object so that we can access the model's attributes.
        """
        self.model = model

    @abstractmethod
    def preprocess(self) -> nn.Module:
        """
        Perform some preprocessing on the model, such as resizing the embedding size
        """
        ...

    @abstractmethod
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """
        Return the dict for the modify policy, the key is the original layer class and the value is the
        argument for the modify layer
        """
        ...

    @abstractmethod
    def new_model_class(self) -> Union[Type[nn.Module], None]:
        """
        Replace the class of the model to substitute the forward and attributes
        """
        ...

    @abstractmethod
    def postprocess(self) -> nn.Module:
        """
        Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head
        """
        ...
```
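
How a `suffix` might be used to locate and replace a submodule can be sketched with plain attribute access. This is a toy under stated assumptions, not the library's implementation: the suffix is assumed to be a dot-separated attribute path, and `get_by_suffix`, `set_by_suffix`, and the `SimpleNamespace` stand-in for a model layer are hypothetical names used only for illustration.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any


def get_by_suffix(root: Any, suffix: str) -> Any:
    """Walk a dot-separated attribute path such as 'attention.query'."""
    return reduce(getattr, suffix.split("."), root)


def set_by_suffix(root: Any, suffix: str, new_value: Any) -> None:
    """Replace the object at the end of the path with `new_value`."""
    *parents, leaf = suffix.split(".")
    parent = reduce(getattr, parents, root)  # returns `root` if there are no parents
    setattr(parent, leaf, new_value)


# Stand-in for a model layer; a real model would be a torch.nn.Module.
layer = SimpleNamespace(attention=SimpleNamespace(query="native_linear"))

original = get_by_suffix(layer, "attention.query")
set_by_suffix(layer, "attention.query", "parallel_linear")
replaced = get_by_suffix(layer, "attention.query")
```

In a real sharder the replacement value would presumably come from something like `target_module.from_native_module(original, process_group, **kwargs)` rather than a string.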

### Model Sharder

`ModelSharder` is the class in charge of sharding the model based on the given policy.

```python
class ModelSharder:

    def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, policy: Policy = None):
        # TODO: input is a cls or an obj
        ...

    def shard(self) -> None:
        """
        Shard the model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
        """
        ...

    def replace_model_class(self) -> None:
        """
        Replace the model's methods and attributes with our own defined class.

        E.g. we can replace the forward function of the original BertForMaskedLM object
        with the forward function we define in BertForMaskedLM_ class.
        """
        ...

    def replace_module(self) -> None:
        """
        Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively.
        """
        ...
```
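
The order in which `shard` could drive the four policy stages is sketched below in plain Python. It is an illustration under assumptions, not the actual `ModelSharder`: `ToyPolicy`, `toy_shard`, and the `calls` log are hypothetical, the model is a plain dict, and class and module replacement are reduced to placeholders. It also assumes that a `None` return from `new_model_class` means the original class is kept.

```python
from typing import List, Optional


class ToyPolicy:
    """Records which stage ran; stands in for a real Policy subclass."""

    def __init__(self) -> None:
        self.model: Optional[dict] = None
        self.calls: List[str] = []

    def set_model(self, model: dict) -> None:
        self.model = model

    def preprocess(self) -> dict:
        self.calls.append("preprocess")
        return self.model

    def new_model_class(self) -> Optional[type]:
        self.calls.append("new_model_class")
        return None  # assumed meaning: keep the original model class

    def module_policy(self) -> dict:
        self.calls.append("module_policy")
        return {}  # no submodules to replace in this toy

    def postprocess(self) -> dict:
        self.calls.append("postprocess")
        return self.model


def toy_shard(model: dict, policy: ToyPolicy) -> dict:
    """Run the four policy stages in the order listed in the Policy section."""
    policy.set_model(model)
    model = policy.preprocess()
    new_class = policy.new_model_class()
    if new_class is not None:
        ...  # a real sharder would swap the model's class here
    for target, description in policy.module_policy().items():
        ...  # a real sharder would replace matching submodules here
    return policy.postprocess()


policy = ToyPolicy()
sharded = toy_shard({"name": "toy-model"}, policy)
```

After the call, `policy.calls` lists the stages in the order they ran, which makes the control flow easy to check in a unit test.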

### User-facing API

We expose only a limited number of APIs to keep the user experience simple and clean.

```python
import colossalai
import torch
from torch.utils.data import DataLoader, Dataset


class ShardFormer:
    """
    Parallelize model based on the given config and policy

    Example:

        shard_former = ShardFormer(shard_config=shard_config)
        shard_former.init_distributed()
        model = shard_former.shard_model(model, policy=policy)
        dataloader = shard_former.shard_dataset(dataset)

    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
        2. Serve as a store for shard config
        """
        self.shard_config = shard_config
        self.pg_manager = None

    def init_distributed(self) -> colossalai.cluster.ProcessGroupManager:
        """
        Initialize the distributed process group according to the shard config
        """
        pg_manager = ...
        self.pg_manager = pg_manager
        return pg_manager

    def shard_model(self, model: torch.nn.Module, policy: Policy = None) -> torch.nn.Module:
        """
        Shard model for TP and PP
        """
        ...

    def shard_dataset(self, dataset: Dataset) -> DataLoader:
        """
        Shard dataset for DP
        """
        ...
```
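
How the data-parallel and tensor-parallel sizes could translate into process groups is sketched below. This is an assumption-laden illustration, not the behavior of `init_distributed`: it assumes the world size equals `data_parallel_size * tensor_parallel_size` and that tensor-parallel ranks are contiguous, and `build_rank_groups` is a hypothetical helper.

```python
from typing import Dict, List


def build_rank_groups(data_parallel_size: int, tensor_parallel_size: int) -> Dict[str, List[List[int]]]:
    """List the ranks in each tensor-parallel and data-parallel group."""
    world_size = data_parallel_size * tensor_parallel_size
    # Assumed layout: consecutive ranks share one tensor-parallel group.
    tp_groups = [
        list(range(start, start + tensor_parallel_size))
        for start in range(0, world_size, tensor_parallel_size)
    ]
    # Ranks holding the same shard index form one data-parallel group.
    dp_groups = [
        list(range(offset, world_size, tensor_parallel_size))
        for offset in range(tensor_parallel_size)
    ]
    return {"tp": tp_groups, "dp": dp_groups}


groups = build_rank_groups(data_parallel_size=2, tensor_parallel_size=2)
```

In a real run, each rank list would typically be turned into a communicator with `torch.distributed.new_group(ranks=...)`; the process group manager mentioned above would own those handles.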