[shardformer] refactored some doc and api (#4137)

* [shardformer] refactored some doc and api

* polish code
pull/4157/head
Frank Lee 2023-07-03 15:29:11 +08:00
parent 7f9b30335b
commit 74257cb446
15 changed files with 355 additions and 490 deletions

View File

@ -15,7 +15,12 @@
- [Policy](#policy)
- [Model Sharder](#model-sharder)
- [User-facing API](#user-facing-api)
- [Shardformer Convergence](#shardformer-convergence)
- [⌨️ Development Notes](#-development-notes)
- [Add New Policy to Shardformer](#add-new-policy-to-shardformer)
- [Write Your Unit Testing](#write-your-unit-testing)
- [📊 Benchmarking](#-benchmarking)
- [System Performance](#system-performance)
- [Convergence](#convergence)
## 🔗 Introduction
@ -40,12 +45,9 @@ config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
# create huggingface model as normal
shard_config = ShardConfig(tensor_parallel_size=2,
data_parallel_size=1,
gather_output=True)
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model).to('cuda')
sharded_model = shard_former.optimize(model).to('cuda')
# do everything like normal
...
@ -67,10 +69,11 @@ class MyPolicy(Policy):
# use customized policy to shard model
my_policy = MyPolicy()
shard_former.shard_model(model, my_policy)
shard_former.optimize(model, my_policy)
```
## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:
@ -112,7 +115,6 @@ Please refer to the code for more details.
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/shardformer_flowchart.png" width="600" />
<br/>
<b>This diagram is deprecated, need to update it</b>
</p>
@ -147,15 +149,13 @@ class ParallelModule(torch.nn.Module):
```python
@dataclass
class ShardConfig:
data_parallel_size: int
tensor_parallel_size: int
tensor_parallel_process_group: ProcessGroup = None
enable_fused_normalization: bool = False
...
# Some possible future config fields
pipeline_parallel_size: int # Support pipeline parallelism
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
inference_only: bool # only inject inference-suitable sharding policy
gather_output: bool # gather the model output
use_flash_attention: bool # whether to use flash attention to speed up attention
```
@ -166,42 +166,42 @@ It is merely a description, the actual sharding will be performed by `ModelShard
We abstract the policy into four stages:
1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding
2. Providing a new class: call `Policy.new_model_class` to get a new class for the model, this class replaces attributes and the forward function
3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted.
4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` objects that tell `ModelSharder` how each submodule's attributes, child parameters, and deeper submodules will be substituted.
3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
``` python
@dataclass
class ModulePolicyDescription:
"""
Describe how the attributes and parameters will be transformed in a policy
r"""
Describe how the attributes and parameters will be transformed in a policy.
Args:
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is
def example_replace_weight(module: torch.nn.Module, process_group):
weight = module.weight
new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight)
sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies the module to be replaced and the target module used to replacement
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive only one argument: module.
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
object which specifies the module to be replaced and the target module used to replace it.
method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
attribute_replacement: Dict[str, Any]
param_replacement: List[Callable]
sub_module_replacement: List[SubModuleReplacementDescription]
attribute_replacement: Dict[str, Any] = None
param_replacement: List[Callable] = None
sub_module_replacement: List[SubModuleReplacementDescription] = None
method_replacement: Dict[str, Callable] = None
@dataclass
class SubModuleReplacementDescription:
"""
r"""
Describe how a submodule will be replaced
Args:
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace the submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False
class Policy(ABC):
@ -230,13 +230,6 @@ class Policy(ABC):
"""
...
@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
"""
replace the class of the model to substitute the forward and attributes
"""
...
@abstractmethod
def postprocess(self) -> nn.Module:
"""
@ -253,8 +246,9 @@ class Policy(ABC):
```python
class ModelSharder:
def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None)
def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None):
#TODO: input is a cls or a obj
...
def shard(self) -> None:
"""
@ -262,15 +256,6 @@ class ModelSharder:
"""
...
def replace_model_class(self) -> None:
"""
Replace the model's methods and attributes with our own defined class.
E.g. we can replace the forward function of the original BertForMaskedLM object
with the forward function we define in BertForMaskedLM_ class.
"""
...
def replace_module(self) -> None:
"""
Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively.
@ -291,7 +276,7 @@ class ShardFormer:
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
model = shard_former.shard_model(model, policy=policy)
model = shard_former.optimize(model, policy=policy)
dataloader = shard_former.shard_dataset(dataset)
"""
@ -326,14 +311,69 @@ class ShardFormer:
...
```
### Shardformer Convergence
## ⌨️ Development Notes
### Add New Policy to Shardformer
This section serves as the guideline for writing new policies and registering them into `shardformer`.
- Step 1. Write your own model policy
You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name. You can implement your policy in this file. You should not import any model zoo library at the top of the file because we do not want to import the library when we do not use the policy. Libraries such as `transformers` should be imported only in the function body when needed.
- Step 2. Register your policy to the autopolicy
Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.
For example, if we register the policy for the BERT model, we just add a key-value pair in the `_POLICY_LIST` dictionary. The key is the `qualname` of the model object (you can get it by `model.__class__.__qualname__`). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.
```python
_POLICY_LIST = {
# BERT
"transformers.models.bert.modeling_bert.BertModel":
PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}
```
### Write Your Unit Testing
This section serves as the guideline for testing the `shardformer` module.
- Step 1. Add your model to the model zoo in the test kits.
Add your model to the `tests/kit/model_zoo` file. This allows you to define test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as an example for reference.
- Step 2. Write your unit testing for the model
Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
- Step 3. Execute your test
When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests.
```bash
# test for your own test file
pytest tests/test_shardformer/test_model/<your-file>.py
# test for the whole shardformer module
pytest tests/test_shardformer
```
## 📊 Benchmarking
### System Performance
To be added.
### Convergence
To validate that training the model using shardformers does not impact its convergence, we [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, and F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| :-----: | :----: | :----: | :----: | :----: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.

View File

@ -51,7 +51,7 @@ def train(args):
if dist.get_world_size() > 1:
shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
shard_former = ShardFormer(shard_config=shard_config)
model = shard_former.shard_model(model)
model = shard_former.optimize(model)
optim = Adam(model.parameters(), lr=args.lr)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps

View File

@ -22,9 +22,11 @@ class SubModuleReplacementDescription:
r"""
Describe how a submodule will be replaced
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace to submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
Args:
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace the submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
@ -35,47 +37,37 @@ class SubModuleReplacementDescription:
@dataclass
class ModulePolicyDescription:
r"""
Describe how the attributes and parameters will be transformed in a policy
Describe how the attributes and parameters will be transformed in a policy.
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
must receive two arguments: module, process_group. One example is
Args:
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
must receive only one argument: module. One example is
```python
def example_replace_weight(module: torch.nn.Module, process_group):
weight = module.weight
new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight)
```
sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies
the module to be replaced and the target module used to replacement
```python
def example_replace_weight(module: torch.nn.Module):
weight = module.weight
new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight)
```
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
object which specifies the module to be replaced and the target module used to replace it.
method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
attribute_replacement: Dict[str, Any]
param_replacement: List[Callable]
sub_module_replacement: List[SubModuleReplacementDescription]
method_replacement: List[Callable] = None
attribute_replacement: Dict[str, Any] = None
param_replacement: List[Callable] = None
sub_module_replacement: List[SubModuleReplacementDescription] = None
method_replacement: Dict[str, Callable] = None
class Policy(ABC):
r"""
The base class for all the policies
For each different model, it should have a different policy class, like BertPolicy for Bert Model
or OPTPolicy for OPT model.
AutoPolicy:
Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
BertForSequenceClassification, etc., for each different Bert model we difine different policy class
and overwrite the method like ``inject_policy`` to modify the forward and backward process.
CustomPolicy:
If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
class for the example.
The base class for all the policies. For each different model, it should have a different policy class,
like BertPolicy for Bert Model or OPTPolicy for OPT model.
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
built-in policies by setting `policy = None`, which is already the default argument for `ShardFormer.optimize`.
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
"""
def __init__(self) -> None:
@ -106,63 +98,24 @@ class Policy(ABC):
def config_sanity_check(self):
"""
Check if the shard config is valid for the model. Raise an exception if the config is invalid.
This method is made abstractmethod with no default implementation because we want the policy writer
to take note of the features supported by their model and policy.
"""
pass
@abstractmethod
def preprocess(self) -> nn.Module:
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
Perform some preprocessing of the model, like reshaping the embedding layer.
"""
pass
@abstractmethod
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer
Return:
Dict for the modify policy,
::
{
origin layer class1 (nn.Module): ModulePolicyDescription(
attribute_replacement = {
"attribute1": value1,
"attribute2": value2,
...
},
param_replacement = [
function1,
function2,
...
],
sub_module_replacement = [
`SubModuleReplacementDescription` description1,
`SubModuleReplacementDescription` description2,
...
]
),
origin layer class2 (nn.Module): ModulePolicyDescription(
...
),
...
}
"""
pass
@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
r"""
Return the new model class for the new model, None means no need to modify the model class
Return:
New model class
E.g.
```
return BertModel_
```
This method returns the module policy, which is a dictionary. The key is the module name or the module object,
and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module
will be transformed.
"""
pass

View File

@ -48,7 +48,6 @@ class BertPolicy(Policy):
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
@ -88,18 +87,16 @@ class BertPolicy(Policy):
)
]),
BertEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
}
# optimization configuration
@ -121,10 +118,6 @@ class BertPolicy(Policy):
),)
return base_policy
def new_model_class(self):
# do nothing
return None
def postprocess(self):
return self.model
@ -148,13 +141,10 @@ class BertForPretrainingPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True}),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
# optimization configuration
@ -191,13 +181,10 @@ class BertLMHeadModelPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True}),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
@ -230,13 +217,10 @@ class BertForMaskedLMPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True}),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
# optimization configuration
@ -272,14 +256,12 @@ class BertForSequenceClassificationPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
@ -297,14 +279,12 @@ class BertForTokenClassificationPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
@ -329,14 +309,12 @@ class BertForMultipleChoicePolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy

View File

@ -98,7 +98,6 @@ class BloomPolicy(Policy):
"self_attention.num_heads":
self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
@ -125,7 +124,6 @@ class BloomPolicy(Policy):
ModulePolicyDescription(attribute_replacement={
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
sub_module_replacement=[
SubModuleReplacementDescription(
@ -160,10 +158,6 @@ class BloomPolicy(Policy):
return base_policy
def new_model_class(self):
# do nothing
return None
def postprocess(self):
return self.model
@ -180,13 +174,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
# add a new item for causal lm
new_item = {
BloomForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
@ -213,13 +204,10 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
# add a new item for sequence classification
new_item = {
BloomForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="score",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
@ -233,17 +221,14 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
# add a new item for token classification
new_item = {
BloomForTokenClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
}
policy.update(new_item)
return policy

View File

@ -31,23 +31,20 @@ class GPT2Policy(Policy):
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
return {
base_policy = {
GPT2Model:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
]),
GPT2Block:
ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
@ -110,9 +107,6 @@ class GPT2Policy(Policy):
return base_policy
def new_model_class(self):
return self.model
def postprocess(self):
return self.model
@ -136,13 +130,10 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
@ -169,13 +160,10 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy

View File

@ -28,7 +28,7 @@ class LlamaPolicy(Policy):
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
return {
base_policy = {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
@ -37,7 +37,6 @@ class LlamaPolicy(Policy):
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
@ -70,14 +69,12 @@ class LlamaPolicy(Policy):
],
),
LlamaModel:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
])
}
# optimization configuration
@ -101,9 +98,6 @@ class LlamaPolicy(Policy):
return base_policy
def new_model_class(self):
return None
def postprocess(self):
return self.model
@ -117,13 +111,10 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
# add a new item for causal lm
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
@ -139,13 +130,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy

View File

@ -31,33 +31,28 @@ class OPTPolicy(Policy):
base_policy = {
OPTDecoder:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
]),
OPTDecoderLayer:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
)
]),
OPTAttention:
ModulePolicyDescription(attribute_replacement={
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q_proj",
@ -95,9 +90,6 @@ class OPTPolicy(Policy):
return base_policy
def new_model_class(self):
return None
def postprocess(self):
return self.model
@ -116,13 +108,10 @@ class OPTForCausalLMPolicy(OPTPolicy):
policy = super().module_policy()
new_item = {
OPTForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)

View File

@ -44,36 +44,30 @@ class T5BasePolicy(Policy):
base_policy = {
T5Stack:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
]),
T5LayerSelfAttention:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
T5LayerCrossAttention:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
T5Attention:
ModulePolicyDescription(attribute_replacement={
"d_model":
@ -83,7 +77,6 @@ class T5BasePolicy(Policy):
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
@ -107,51 +100,44 @@ class T5BasePolicy(Policy):
ignore_if_not_exist=True)
]),
T5LayerFF:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
T5DenseGatedActDense:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(suffix="wo",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
T5DenseActDense:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
}
# optimization configuration
@ -167,9 +153,6 @@ class T5BasePolicy(Policy):
return base_policy
def new_model_class(self):
return None
def postprocess(self):
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
@ -185,14 +168,12 @@ class T5ModelPolicy(T5BasePolicy):
from transformers import T5Model
base_policy = super().module_policy()
base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
return base_policy
@ -202,18 +183,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
from transformers import T5ForConditionalGeneration
policy = super().module_policy()
policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
return policy
def postprocess(self):
@ -235,14 +212,12 @@ class T5EncoderPolicy(T5BasePolicy):
from transformers import T5EncoderModel
base_policy = super().module_policy()
base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
return base_policy
def postprocess(self):

View File

@ -28,16 +28,14 @@ class ViTPolicy(Policy):
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
return {
base_policy = {
ViTEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForReplicatedInput,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForReplicatedInput,
)
]),
ViTLayer:
ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
@ -45,7 +43,6 @@ class ViTPolicy(Policy):
"attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",

View File

@ -3,8 +3,6 @@ from dataclasses import dataclass
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.cluster.dist_coordinator import DistCoordinator
__all__ = ['ShardConfig']
@ -15,7 +13,9 @@ class ShardConfig:
Args:
tensor_parallel_process_group (ProcessGroup): The process group for tensor parallelism, defaults to None, which is the global process group.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
"""
tensor_parallel_process_group: ProcessGroup = None
enable_fused_normalization: bool = False
@ -45,4 +45,4 @@ class ShardConfig:
Turn on all optimization.
"""
# you can add all the optimization flags here
self.fused_layernorm = True
self.enable_fused_normalization = True

View File

@ -1,9 +1,7 @@
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Union
import torch.nn as nn
from colossalai.cluster.process_group_manager import ProcessGroupManager
from .._utils import getattr_, setattr_
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
@ -34,7 +32,6 @@ class ModelSharder(object):
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self._preprocess()
self._replace_model_class()
self._replace_module()
self._postprocess()
@ -44,27 +41,6 @@ class ModelSharder(object):
def _postprocess(self) -> None:
self.model = self.policy.postprocess()
def _replace_model_class(self,) -> None:
r"""
Replace the model to policy defined model
Mainly modify the forward and backward to fit distributed model
e.g.
::
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
new_model_class = self.policy.new_model_class()
if new_model_class is None:
return
for key in new_model_class.__dict__.keys():
if hasattr(self.model.__class__, key):
setattr(
self.model.__class__,
key,
getattr(new_model_class, key),
)
def _replace_module(self,) -> None:
r"""
Replace the module according to the policy, and replace the module one by one
@ -73,19 +49,18 @@ class ModelSharder(object):
model (:class:`torch.nn.Module`): The model to shard
"""
module_descriptions = self.policy.module_policy()
for module_description in module_descriptions.items():
origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement
param_replacement = module_description[1].param_replacement
sub_module_replacement = module_description[1].sub_module_replacement
method_replacement = module_description[1].method_replacement
self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
for layer_cls, module_description in module_descriptions.items():
attr_replacement = module_description.attribute_replacement
param_replacement = module_description.param_replacement
sub_module_replacement = module_description.sub_module_replacement
method_replacement = module_description.method_replacement
self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement,
method_replacement, sub_module_replacement)
def _recursive_replace_layer(
self,
module: nn.Module,
origin_cls: nn.Module,
origin_cls: Union[str, nn.Module],
attr_replacement: Dict[str, Any],
param_replacement: List[Callable],
method_replacement: Dict[str, Callable],
@ -95,17 +70,25 @@ class ModelSharder(object):
Recursively apply the layer replacement operation
Args:
layer (:class:`torch.nn.Module`): The object of layer to shard
origin_cls (:class:`transformers.model`): The origin layer class
module (torch.nn.Module): The module object to shard
origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
attr_replacement (Dict): The attribute dict to modify
param_replacement (List[Callable]): The function list to get parameter shard information in policy
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
"""
if module.__class__ == origin_cls:
self._replace_attr(module, attr_replacement)
self._replace_param(module, param_replacement)
self._replace_method(module, method_replacement)
self._replace_sub_module(module, sub_module_replacement)
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
(module.__class__ == origin_cls):
if attr_replacement is not None:
self._replace_attr(module, attr_replacement)
if param_replacement is not None:
self._replace_param(module, param_replacement)
if method_replacement is not None:
self._replace_method(module, method_replacement)
if sub_module_replacement is not None:
self._replace_sub_module(module, sub_module_replacement)
for name, child in module.named_children():
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
@ -138,13 +121,10 @@ class ModelSharder(object):
layer (:class:`torch.nn.Module`): The object of layer to shard
param_replacement (List[Callable]): The function list to get parameter shard information in policy
"""
# TODO: support parameter shard
pass
for param_func in param_replacement:
param_func(module)
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
if method_replacement is None:
return
for method_name, new_method in method_replacement.items():
# bind the new method to the module
setattr(module, method_name, new_method.__get__(module, module.__class__))
@ -158,8 +138,8 @@ class ModelSharder(object):
Shard one layer according to the policy. The layer should be of the same class as a key in the dict returned by the policy's module_policy.
Args:
org_layer (:class:`torch.nn.Module`): The origin layer object to shard
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
org_layer (torch.nn.Module): The origin layer object to shard
sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
"""
for description in sub_module_replacement:

View File

@ -22,27 +22,19 @@ class ShardFormer:
colossalai.launch_from_torch(config={})
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig(
tensor_parallel_size=2,
tensor_parallel_mode='1d',
)
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
model = shard_former.shard_model(org_model)
model = shard_former.optimize(org_model)
```
"""
def __init__(self, shard_config: ShardConfig):
"""
Do two things:
1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
2. serve as a store for
"""
self.coordinator = DistCoordinator()
self.shard_config = shard_config
def shard_model(self, model: nn.Module, policy: Policy = None):
def optimize(self, model: nn.Module, policy: Policy = None):
r"""
The function is used to shard the PyTorch model.
This method will optimize the model based on the given policy.
Args:
model (`torch.nn.Module`): the original huggingface model

View File

@ -11,7 +11,7 @@ def build_model(model_fn):
shard_config = ShardConfig(enable_fused_normalization=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.shard_model(model_copy).cuda()
sharded_model = shard_former.optimize(model_copy).cuda()
return org_model, sharded_model

View File

@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port):
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create and shard model
model = model_fn().cuda()
sharded_model = shardformer.shard_model(model)
sharded_model = shardformer.optimize(model)
# add ddp
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)