mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] refactored some doc and api (#4137)
* [shardformer] refactored some doc and api

* polish code
parent 7f9b30335b
commit 74257cb446
@ -15,7 +15,12 @@
- [Policy](#policy)
- [Model Sharder](#model-sharder)
- [User-facing API](#user-facing-api)
- [Shardformer Convergence](#shardformer-convergence)
- [⌨️ Development Notes](#️-development-notes)
- [Add New Policy to Shardformer](#add-new-policy-to-shardformer)
- [Write Your Unit Testing](#write-your-unit-testing)
- [📊 Benchmarking](#-benchmarking)
- [System Performance](#system-performance)
- [Convergence](#convergence)

## 🔗 Introduction
@ -40,12 +45,9 @@ config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# create huggingface model as normal
shard_config = ShardConfig(tensor_parallel_size=2,
                           data_parallel_size=1,
                           gather_output=True)
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model).to('cuda')
sharded_model = shard_former.optimize(model).to('cuda')

# do everything like normal
...

@ -67,10 +69,11 @@ class MyPolicy(Policy):

# use customized policy to shard model
my_policy = MyPolicy()
shard_former.shard_model(model, my_policy)
shard_former.optimize(model, my_policy)

```

## 🗺 Roadmap

We will follow this roadmap to develop Shardformer:
@ -112,7 +115,6 @@ Please refer to the code for more details.
|
|||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/shardformer_flowchart.png" width="600" />
|
||||
<br/>
|
||||
<b>This diagram is deprecated and needs to be updated</b>
|
||||
</p>
|
||||
|
||||
|
||||
|
@ -147,15 +149,13 @@ class ParallelModule(torch.nn.Module):
|
|||
```python
|
||||
@dataclass
|
||||
class ShardConfig:
|
||||
data_parallel_size: int
|
||||
tensor_parallel_size: int
|
||||
tensor_parallel_process_group: ProcessGroup = None
|
||||
enable_fused_normalization: bool = False
|
||||
...
|
||||
|
||||
# Some possible future config fields
|
||||
pipeline_parallel_size: int # Support pipeline parallelism
|
||||
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
|
||||
inference_only: bool # only inject inference-suitable sharding policy
|
||||
gather_output: bool # gather the model output
|
||||
use_flash_attention: bool # whether to use flash attention to speed up attention
|
||||
```
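
For a concrete sense of how this config is used today, here is a minimal construction sketch. It assumes `ShardConfig` is exported from the top-level `colossalai.shardformer` package; only `enable_fused_normalization` (used by the example training script and the tests later in this diff) and `tensor_parallel_process_group` exist at the moment, the other fields listed above are candidates.

```python
from colossalai.shardformer import ShardConfig

# default: tensor parallelism over the global process group, no fused layernorm
shard_config = ShardConfig()

# opt in to the fused layernorm kernel
shard_config = ShardConfig(enable_fused_normalization=True)
```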
@ -166,42 +166,42 @@ It is merely a description, the actual sharding will be performed by `ModelShard
|
|||
We abstract the policy into four stages:
|
||||
|
||||
1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding
|
||||
2. Providing a new class: call `Policy.new_model_class` to get a new class for the model, this class replaces attributes and the forward function
|
||||
3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted.
|
||||
4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
|
||||
2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` objects to tell `ModelSharder` how the submodule's attributes, child parameters, and deeper submodules will be substituted.
|
||||
3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
|
||||
|
||||
``` python
|
||||
@dataclass
|
||||
class ModulePolicyDescription:
|
||||
"""
|
||||
Describe how the attributes and parameters will be transformed in a policy
|
||||
r"""
|
||||
Describe how the attributes and parameters will be transformed in a policy.
|
||||
|
||||
Args:
|
||||
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
|
||||
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is
|
||||
def example_replace_weight(module: torch.nn.Module, process_group):
|
||||
weight = module.weight
|
||||
new_weight = shard_rowwise(weight, process_group)
|
||||
module.weight = torch.nn.Parameter(new_weight)
|
||||
sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies the module to be replaced and the target module used to replacement
|
||||
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive only one argument: module.
|
||||
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
|
||||
object which specifies the module to be replaced and the target module used for the replacement.
|
||||
method_replacement (Dict[str, Callable]): key is the method name, value is the method used for the replacement
|
||||
"""
|
||||
attribute_replacement: Dict[str, Any]
|
||||
param_replacement: List[Callable]
|
||||
sub_module_replacement: List[SubModuleReplacementDescription]
|
||||
attribute_replacement: Dict[str, Any] = None
|
||||
param_replacement: List[Callable] = None
|
||||
sub_module_replacement: List[SubModuleReplacementDescription] = None
|
||||
method_replacement: Dict[str, Callable] = None
|
||||
|
||||
@dataclass
|
||||
class SubModuleReplacementDescription:
|
||||
"""
|
||||
r"""
|
||||
Describe how a submodule will be replaced
|
||||
|
||||
Args:
|
||||
suffix (str): used to get the submodule object
|
||||
target_module (ParallelModule): specifies the module class used to replace the submodule
|
||||
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
|
||||
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
|
||||
"""
|
||||
suffix: str
|
||||
target_module: ParallelModule
|
||||
kwargs: Dict[str, Any] = None
|
||||
ignore_if_not_exist: bool = False
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
|
@ -230,13 +230,6 @@ class Policy(ABC):
|
|||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def new_model_class(self) -> Union[Type[nn.Module], None]:
|
||||
"""
|
||||
replace the class of the model to substitute the forward and attributes
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self) -> nn.Module:
|
||||
"""
|
||||
|
@ -253,8 +246,9 @@ class Policy(ABC):
|
|||
```python
|
||||
class ModelSharder:
|
||||
|
||||
def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None)
|
||||
def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None):
|
||||
# TODO: input is a cls or an obj
|
||||
...
|
||||
|
||||
def shard(self) -> None:
|
||||
"""
|
||||
|
@ -262,15 +256,6 @@ class ModelSharder:
|
|||
"""
|
||||
...
|
||||
|
||||
def replace_model_class(self) -> None:
|
||||
"""
|
||||
Replace the model's methods and attributes with our own defined class.
|
||||
|
||||
E.g. we can replace the forward function of the original BertForMaskedLM object
|
||||
with the forward function we define in BertForMaskedLM_ class.
|
||||
"""
|
||||
...
|
||||
|
||||
def replace_module(self) -> None:
|
||||
"""
|
||||
Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively.
|
||||
|
@ -291,7 +276,7 @@ class ShardFormer:
|
|||
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
shard_former.init_distributed()
|
||||
model = shard_former.shard_model(model, policy=policy)
|
||||
model = shard_former.optimize(model, policy=policy)
|
||||
dataloader = shard_former.shard_dataset(dataset)
|
||||
|
||||
"""
|
||||
|
@ -326,14 +311,69 @@ class ShardFormer:
    ...
```

### Shardformer Convergence
## ⌨️ Development Notes

### Add New Policy to Shardformer

This section serves as a guideline for writing new policies and registering them into `shardformer`.

- Step 1. Write your own model policy

You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name. You can implement your policy in this file. You should not import any model zoo library at the header section of the file, because we do not want to import the library when the policy is not used. Libraries such as `transformers` should be imported only in the function body when needed.
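
A minimal sketch of such a policy file is shown below. The file name, the `MyModelPolicy` class, and the `colossalai.shardformer.layer` import path are illustrative assumptions; the built-in policies touched in this commit (e.g. `BertPolicy`) are the authoritative reference. The key point is the deferred `transformers` import inside `module_policy`.

```python
# colossalai/shardformer/policies/mymodel.py  (file name is illustrative)
from colossalai.shardformer.policies.basepolicy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)


class MyModelPolicy(Policy):

    def config_sanity_check(self):
        # validate that self.shard_config is supported by this policy
        pass

    def preprocess(self):
        # optional prior work, e.g. resize the embedding so that the vocab size
        # is divisible by the tensor parallel size
        return self.model

    def module_policy(self):
        # import the model zoo library lazily, only when the policy is actually used
        from transformers.models.bert.modeling_bert import BertLayer

        from colossalai.shardformer.layer import Linear1D_Col

        return {
            BertLayer:
                ModulePolicyDescription(sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.self.query",
                        target_module=Linear1D_Col,
                    ),
                ]),
        }

    def postprocess(self):
        # optional post work, e.g. re-bind tied weights after sharding
        return self.model
```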

- Step 2. Register your policy to the autopolicy

Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.

For example, if we register the policy for the BERT model, we just add a key-value pair to the `_POLICY_LIST` dictionary. The key is the `qualname` of the model object (you can get it via `model.__class__.__qualname__`). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.

```python
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}
```
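
As a quick sanity check, the registry key is simply the model class's module path plus its `qualname`. A small illustrative snippet (assuming `transformers` is installed):

```python
from transformers import BertConfig, BertModel

model = BertModel(BertConfig())
print(model.__class__.__qualname__)
# 'BertModel'
print(f"{model.__class__.__module__}.{model.__class__.__qualname__}")
# 'transformers.models.bert.modeling_bert.BertModel'  -> the key used in _POLICY_LIST
```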

### Write Your Unit Testing

This section serves as a guideline for testing the `shardformer` module.

- Step 1. Add your model to the model zoo in the test kits.

Add your model to the `tests/kit/model_zoo` folder. This allows you to define the test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as an example for reference.
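
A rough sketch of such a registration is given below. The exact `model_zoo.register` keyword arguments and the `ModelAttribute` helper are assumptions inferred from the existing kits (the test loop later in this diff unpacks `model_fn`, `data_gen_fn`, `output_transform_fn` and `loss_fn` per entry); mirror `llama.py` for the authoritative form.

```python
# tests/kit/model_zoo/transformers/mymodel.py  (illustrative file)
import torch
import transformers

from ..registry import ModelAttribute, model_zoo


def data_gen_fn():
    # a tiny dummy batch that the model's forward() accepts
    input_ids = torch.tensor([[101, 7592, 2088, 102]], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


output_transform_fn = lambda x: x
loss_fn = lambda x: x.last_hidden_state.mean()

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)

model_zoo.register(name='transformers_mymodel',
                   model_fn=lambda: transformers.BertModel(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
```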

- Step 2. Write your unit test for the model

Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
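
A rough sketch of what such a test could look like follows. The launch/spawn helpers (`colossalai.launch`, `colossalai.testing.spawn`, `rerun_if_address_is_in_use`) and the shape-only check are assumptions for illustration; the existing files under `tests/test_shardformer/test_model` are the authoritative pattern.

```python
# tests/test_shardformer/test_model/test_shard_mymodel.py  (illustrative file)
import copy

import pytest
import torch

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_mymodel(rank, world_size, port):
    # set up the distributed environment for this worker process
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    from transformers import BertConfig, BertModel
    org_model = BertModel(BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)).cuda()

    # shard a deep copy so the original model stays intact for comparison
    shard_config = ShardConfig(enable_fused_normalization=True)
    sharded_model = ShardFormer(shard_config=shard_config).optimize(copy.deepcopy(org_model)).cuda()

    input_ids = torch.randint(0, 100, (2, 16)).cuda()
    org_out = org_model(input_ids).last_hidden_state
    shard_out = sharded_model(input_ids).last_hidden_state
    assert org_out.shape == shard_out.shape


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mymodel():
    spawn(check_mymodel, nprocs=2)


if __name__ == '__main__':
    test_mymodel()
```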

- Step 3. Execute your test

When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests.

```bash
# test for your own test file
pytest tests/test_shardformer/test_model/<your-file>.py

# test for the whole shardformer module
pytest tests/test_shardformer
```

## 📊 Benchmarking

### System Performance

To be added.

### Convergence

To validate that training the model using shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both the shardformer and non-shardformer approaches and compared the accuracy, F1 score, and loss of the training results.

| accuracy | f1 | loss | GPU number | model shard |
| :-----: | :----: | :----: | :----: | :----: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |

Overall, the results demonstrate that using shardformer during model training does not affect the convergence.
|
||||
|
|
|
@ -51,7 +51,7 @@ def train(args):
|
|||
if dist.get_world_size() > 1:
|
||||
shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
model = shard_former.shard_model(model)
|
||||
model = shard_former.optimize(model)
|
||||
|
||||
optim = Adam(model.parameters(), lr=args.lr)
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
|
|
|
@ -22,9 +22,11 @@ class SubModuleReplacementDescription:
|
|||
r"""
|
||||
Describe how a submodule will be replaced
|
||||
|
||||
suffix (str): used to get the submodule object
|
||||
target_module (ParallelModule): specifies the module class used to replace to submodule
|
||||
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
|
||||
Args:
|
||||
suffix (str): used to get the submodule object
|
||||
target_module (ParallelModule): specifies the module class used to replace the submodule
|
||||
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
|
||||
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
|
||||
"""
|
||||
suffix: str
|
||||
target_module: ParallelModule
|
||||
|
@ -35,47 +37,37 @@ class SubModuleReplacementDescription:
|
|||
@dataclass
|
||||
class ModulePolicyDescription:
|
||||
r"""
|
||||
Describe how the attributes and parameters will be transformed in a policy
|
||||
Describe how the attributes and parameters will be transformed in a policy.
|
||||
|
||||
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
|
||||
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
|
||||
must receive two arguments: module, process_group. One example is
|
||||
Args:
|
||||
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
|
||||
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
|
||||
must receive only one argument: module. One example is
|
||||
|
||||
```python
|
||||
def example_replace_weight(module: torch.nn.Module, process_group):
|
||||
weight = module.weight
|
||||
new_weight = shard_rowwise(weight, process_group)
|
||||
module.weight = torch.nn.Parameter(new_weight)
|
||||
```
|
||||
|
||||
sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies
|
||||
the module to be replaced and the target module used to replacement
|
||||
```python
|
||||
def example_replace_weight(module: torch.nn.Module):
|
||||
weight = module.weight
|
||||
new_weight = shard_rowwise(weight, process_group)
|
||||
module.weight = torch.nn.Parameter(new_weight)
|
||||
```
|
||||
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
|
||||
object which specifies the module to be replaced and the target module used for the replacement.
|
||||
method_replacement (Dict[str, Callable]): key is the method name, value is the method used for the replacement
|
||||
"""
|
||||
attribute_replacement: Dict[str, Any]
|
||||
param_replacement: List[Callable]
|
||||
sub_module_replacement: List[SubModuleReplacementDescription]
|
||||
method_replacement: List[Callable] = None
|
||||
attribute_replacement: Dict[str, Any] = None
|
||||
param_replacement: List[Callable] = None
|
||||
sub_module_replacement: List[SubModuleReplacementDescription] = None
|
||||
method_replacement: Dict[str, Callable] = None
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
r"""
|
||||
The base class for all the policies
|
||||
|
||||
For each different model, it should have a different policy class, like BertPolicy for Bert Model
|
||||
or OPTPolicy for OPT model.
|
||||
|
||||
AutoPolicy:
|
||||
Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
|
||||
to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
|
||||
like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
|
||||
BertForSequenceClassification, etc., for each different Bert model we difine different policy class
|
||||
and overwrite the method like ``inject_policy`` to modify the forward and backward process.
|
||||
|
||||
CustomPolicy:
|
||||
If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
|
||||
all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
|
||||
class for the example.
|
||||
The base class for all the policies. For each different model, it should have a different policy class,
|
||||
like BertPolicy for Bert Model or OPTPolicy for OPT model.
|
||||
|
||||
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
|
||||
built-in policies by setting `policy = None`, which is already the default argument for `ShardFormer.optimize`.
|
||||
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -106,63 +98,24 @@ class Policy(ABC):
|
|||
def config_sanity_check(self):
|
||||
"""
|
||||
Check if the shard config is valid for the model. Raise an exception if the config is invalid.
|
||||
This method is made an abstract method with no default implementation because we want the policy writer
|
||||
to take note of the features supported by their model and policy.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self) -> nn.Module:
|
||||
r"""
|
||||
Perform some preprocessing of the model, like reshaping the embedding layer
|
||||
Perform some preprocessing of the model, like reshaping the embedding layer.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
r"""
|
||||
Return the dict for the modify policy, the key is the original layer class and the value is the
|
||||
argument for the modify layer
|
||||
|
||||
Return:
|
||||
Dict for the modify policy,
|
||||
::
|
||||
{
|
||||
origin layer class1 (nn.Module): ModulePolicyDescription(
|
||||
attribute_replacement = {
|
||||
"attribute1": value1,
|
||||
"attribute2": value2,
|
||||
...
|
||||
},
|
||||
param_replacement = [
|
||||
function1,
|
||||
function2,
|
||||
...
|
||||
],
|
||||
sub_module_replacement = [
|
||||
`SubModuleReplacementDescription` description1,
|
||||
`SubModuleReplacementDescription` description2,
|
||||
...
|
||||
]
|
||||
),
|
||||
origin layer class2 (nn.Module): ModulePolicyDescription(
|
||||
...
|
||||
),
|
||||
...
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def new_model_class(self) -> Union[Type[nn.Module], None]:
|
||||
r"""
|
||||
Return the new model class for the new model, None means no need to modify the model class
|
||||
|
||||
Return:
|
||||
New model class
|
||||
|
||||
E.g.
|
||||
```
|
||||
return BertModel_
|
||||
```
|
||||
This method returns the module policy, which is a dictionary. The key is the module name or the module object,
|
||||
and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module
|
||||
will be transformed.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
@ -48,7 +48,6 @@ class BertPolicy(Policy):
|
|||
"crossattention.self.num_attention_heads":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.query",
|
||||
|
@ -88,18 +87,16 @@ class BertPolicy(Policy):
|
|||
)
|
||||
]),
|
||||
BertEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
|
@ -121,10 +118,6 @@ class BertPolicy(Policy):
|
|||
),)
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
@ -148,13 +141,10 @@ class BertForPretrainingPolicy(BertPolicy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
|
@ -191,13 +181,10 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
|
@ -230,13 +217,10 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
|
@ -272,14 +256,12 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForSequenceClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
@ -297,14 +279,12 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForTokenClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
@ -329,14 +309,12 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForMultipleChoice:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
|
|
@ -98,7 +98,6 @@ class BloomPolicy(Policy):
|
|||
"self_attention.num_heads":
|
||||
self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
|
@ -125,7 +124,6 @@ class BloomPolicy(Policy):
|
|||
ModulePolicyDescription(attribute_replacement={
|
||||
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -160,10 +158,6 @@ class BloomPolicy(Policy):
|
|||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
@ -180,13 +174,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
|||
# add a new item for causal lm
|
||||
new_item = {
|
||||
BloomForCausalLM:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
@ -213,13 +204,10 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
|||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
BloomForSequenceClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="score",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
@ -233,17 +221,14 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
|||
# add a new item for token classification
|
||||
new_item = {
|
||||
BloomForTokenClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="classifier",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
|
|
@ -31,23 +31,20 @@ class GPT2Policy(Policy):
|
|||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
|
||||
|
||||
return {
|
||||
base_policy = {
|
||||
GPT2Model:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
]),
|
||||
GPT2Block:
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_attn",
|
||||
|
@ -110,9 +107,6 @@ class GPT2Policy(Policy):
|
|||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return self.model
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
@ -136,13 +130,10 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
GPT2LMHeadModel:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True})
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
@ -169,13 +160,10 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
|||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
GPT2DoubleHeadsModel:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True})
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
|
|
@ -28,7 +28,7 @@ class LlamaPolicy(Policy):
|
|||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||
|
||||
return {
|
||||
base_policy = {
|
||||
LlamaDecoderLayer:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
|
@ -37,7 +37,6 @@ class LlamaPolicy(Policy):
|
|||
"self_attn.num_heads":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
|
@ -70,14 +69,12 @@ class LlamaPolicy(Policy):
|
|||
],
|
||||
),
|
||||
LlamaModel:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
|
@ -101,9 +98,6 @@ class LlamaPolicy(Policy):
|
|||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
@ -117,13 +111,10 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
# add a new item for causal lm
|
||||
new_item = {
|
||||
LlamaForCausalLM:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
@ -139,13 +130,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
LlamaForSequenceClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="score",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
|
|
@ -31,33 +31,28 @@ class OPTPolicy(Policy):
|
|||
|
||||
base_policy = {
|
||||
OPTDecoder:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
]),
|
||||
OPTDecoderLayer:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=Linear1D_Row,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=Linear1D_Row,
|
||||
)
|
||||
]),
|
||||
OPTAttention:
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="q_proj",
|
||||
|
@ -95,9 +90,6 @@ class OPTPolicy(Policy):
|
|||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
@ -116,13 +108,10 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||
policy = super().module_policy()
|
||||
new_item = {
|
||||
OPTForCausalLM:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
|
||||
policy.update(new_item)
|
||||
|
|
|
@ -44,36 +44,30 @@ class T5BasePolicy(Policy):
|
|||
|
||||
base_policy = {
|
||||
T5Stack:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
)
|
||||
]),
|
||||
T5LayerSelfAttention:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
T5LayerCrossAttention:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
T5Attention:
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"d_model":
|
||||
|
@ -83,7 +77,6 @@ class T5BasePolicy(Policy):
|
|||
"inner_dim":
|
||||
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="q",
|
||||
|
@ -107,51 +100,44 @@ class T5BasePolicy(Policy):
|
|||
ignore_if_not_exist=True)
|
||||
]),
|
||||
T5LayerFF:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
T5DenseGatedActDense:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_0",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_1",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(suffix="wo",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_0",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_1",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
T5DenseActDense:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
|
@ -167,9 +153,6 @@ class T5BasePolicy(Policy):
|
|||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
|
||||
|
||||
|
@ -185,14 +168,12 @@ class T5ModelPolicy(T5BasePolicy):
|
|||
from transformers import T5Model
|
||||
|
||||
base_policy = super().module_policy()
|
||||
base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
return base_policy
|
||||
|
||||
|
||||
|
@ -202,18 +183,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
|||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
policy = super().module_policy()
|
||||
policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
@ -235,14 +212,12 @@ class T5EncoderPolicy(T5BasePolicy):
|
|||
from transformers import T5EncoderModel
|
||||
|
||||
base_policy = super().module_policy()
|
||||
base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
return base_policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@ -28,16 +28,14 @@ class ViTPolicy(Policy):
|
|||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
|
||||
|
||||
return {
|
||||
base_policy = {
|
||||
ViTEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForReplicatedInput,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForReplicatedInput,
|
||||
)
|
||||
]),
|
||||
ViTLayer:
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"attention.attention.num_attention_heads":
|
||||
|
@ -45,7 +43,6 @@ class ViTPolicy(Policy):
|
|||
"attention.attention.all_head_size":
|
||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
|
|
|
@ -3,8 +3,6 @@ from dataclasses import dataclass
|
|||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.cluster.dist_coordinator import DistCoordinator
|
||||
|
||||
__all__ = ['ShardConfig']
|
||||
|
||||
|
||||
|
@ -15,7 +13,9 @@ class ShardConfig:
|
|||
|
||||
Args:
|
||||
tensor_parallel_process_group (ProcessGroup): The process group for tensor parallelism, defaults to None, which is the global process group.
|
||||
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
|
||||
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
|
||||
enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True.
|
||||
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
|
||||
"""
|
||||
tensor_parallel_process_group: ProcessGroup = None
|
||||
enable_fused_normalization: bool = False
|
||||
|
@ -45,4 +45,4 @@ class ShardConfig:
|
|||
Turn on all optimization.
|
||||
"""
|
||||
# you can add all the optimization flag here
|
||||
self.fused_layernorm = True
|
||||
self.enable_fused_normalization = True
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.cluster.process_group_manager import ProcessGroupManager
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..policies.autopolicy import get_autopolicy
|
||||
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
|
||||
|
@ -34,7 +32,6 @@ class ModelSharder(object):
|
|||
self.policy.set_model(self.model)
|
||||
self.policy.set_shard_config(self.shard_config)
|
||||
self._preprocess()
|
||||
self._replace_model_class()
|
||||
self._replace_module()
|
||||
self._postprocess()
|
||||
|
||||
|
@ -44,27 +41,6 @@ class ModelSharder(object):
|
|||
def _postprocess(self) -> None:
|
||||
self.model = self.policy.postprocess()
|
||||
|
||||
def _replace_model_class(self,) -> None:
|
||||
r"""
|
||||
Replace the model to policy defined model
|
||||
Mainly modify the forward and backward to fit distributed model
|
||||
|
||||
e.g.
|
||||
::
|
||||
BertForMaskedLM.forward -> BertForMaskedLM_.forward
|
||||
"""
|
||||
new_model_class = self.policy.new_model_class()
|
||||
if new_model_class is None:
|
||||
return
|
||||
|
||||
for key in new_model_class.__dict__.keys():
|
||||
if hasattr(self.model.__class__, key):
|
||||
setattr(
|
||||
self.model.__class__,
|
||||
key,
|
||||
getattr(new_model_class, key),
|
||||
)
|
||||
|
||||
def _replace_module(self,) -> None:
|
||||
r"""
|
||||
Replace the module according to the policy, and replace the module one by one
|
||||
|
@ -73,19 +49,18 @@ class ModelSharder(object):
|
|||
model (:class:`torch.nn.Module`): The model to shard
|
||||
"""
|
||||
module_descriptions = self.policy.module_policy()
|
||||
for module_description in module_descriptions.items():
|
||||
origin_layer_cls = module_description[0]
|
||||
attr_replacement = module_description[1].attribute_replacement
|
||||
param_replacement = module_description[1].param_replacement
|
||||
sub_module_replacement = module_description[1].sub_module_replacement
|
||||
method_replacement = module_description[1].method_replacement
|
||||
self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
|
||||
for layer_cls, module_description in module_descriptions.items():
|
||||
attr_replacement = module_description.attribute_replacement
|
||||
param_replacement = module_description.param_replacement
|
||||
sub_module_replacement = module_description.sub_module_replacement
|
||||
method_replacement = module_description.method_replacement
|
||||
self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement,
|
||||
method_replacement, sub_module_replacement)
|
||||
|
||||
def _recursive_replace_layer(
|
||||
self,
|
||||
module: nn.Module,
|
||||
origin_cls: nn.Module,
|
||||
origin_cls: Union[str, nn.Module],
|
||||
attr_replacement: Dict[str, Any],
|
||||
param_replacement: List[Callable],
|
||||
method_replacement: Dict[str, Callable],
|
||||
|
@ -95,17 +70,25 @@ class ModelSharder(object):
|
|||
Recursively replace the layers according to the policy
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The object of layer to shard
|
||||
origin_cls (:class:`transformers.model`): The origin layer class
|
||||
layer (torch.nn.Module): The object of layer to shard
|
||||
origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
|
||||
attr_replacement (Dict): The attribute dict to modify
|
||||
param_replacement (List[Callable]): The function list to get parameter shard information in policy
|
||||
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
|
||||
"""
|
||||
if module.__class__ == origin_cls:
|
||||
self._replace_attr(module, attr_replacement)
|
||||
self._replace_param(module, param_replacement)
|
||||
self._replace_method(module, method_replacement)
|
||||
self._replace_sub_module(module, sub_module_replacement)
|
||||
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
|
||||
(module.__class__ == origin_cls):
|
||||
if attr_replacement is not None:
|
||||
self._replace_attr(module, attr_replacement)
|
||||
|
||||
if param_replacement is not None:
|
||||
self._replace_param(module, param_replacement)
|
||||
|
||||
if method_replacement is not None:
|
||||
self._replace_method(module, method_replacement)
|
||||
|
||||
if sub_module_replacement is not None:
|
||||
self._replace_sub_module(module, sub_module_replacement)
|
||||
|
||||
for name, child in module.named_children():
|
||||
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
|
||||
|
@ -138,13 +121,10 @@ class ModelSharder(object):
|
|||
layer (:class:`torch.nn.Module`): The object of layer to shard
|
||||
param_replacement (List[Callable]): The function list to get parameter shard information in policy
|
||||
"""
|
||||
# TODO: support parameter shard
|
||||
pass
|
||||
for param_func in param_replacement:
|
||||
param_func(module)
|
||||
|
||||
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
|
||||
if method_replacement is None:
|
||||
return
|
||||
|
||||
for method_name, new_method in method_replacement.items():
|
||||
# bind the new method to the module
|
||||
setattr(module, method_name, new_method.__get__(module, module.__class__))
|
||||
|
@ -158,8 +138,8 @@ class ModelSharder(object):
|
|||
Shard one layer according to the policy; the layer should be the same class as the key in the policy's argument_policy return dict
|
||||
|
||||
Args:
|
||||
org_layer (:class:`torch.nn.Module`): The origin layer object to shard
|
||||
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
|
||||
org_layer (torch.nn.Module): The origin layer object to shard
|
||||
sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
|
||||
|
||||
"""
|
||||
for description in sub_module_replacement:
|
||||
|
|
|
@ -22,27 +22,19 @@ class ShardFormer:
|
|||
colossalai.launch_from_torch(config={})
|
||||
|
||||
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||
shard_config = ShardConfig(
|
||||
tensor_parallel_size=2,
|
||||
tensor_parallel_mode='1d',
|
||||
)
|
||||
shard_config = ShardConfig()
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
model = shard_former.shard_model(org_model)
|
||||
model = shard_former.optimize(org_model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, shard_config: ShardConfig):
|
||||
"""
|
||||
Do two things:
|
||||
1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
|
||||
2. Serve as a store for the shard config
|
||||
"""
|
||||
self.coordinator = DistCoordinator()
|
||||
self.shard_config = shard_config
|
||||
|
||||
def shard_model(self, model: nn.Module, policy: Policy = None):
|
||||
def optimize(self, model: nn.Module, policy: Policy = None):
|
||||
r"""
|
||||
The function is used to shard the PyTorch model.
|
||||
This method will optimize the model based on the given policy.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): the original huggingface model
|
||||
|
|
|
@ -11,7 +11,7 @@ def build_model(model_fn):
|
|||
shard_config = ShardConfig(enable_fused_normalization=True)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
sharded_model = shard_former.shard_model(model_copy).cuda()
|
||||
sharded_model = shard_former.optimize(model_copy).cuda()
|
||||
return org_model, sharded_model
|
||||
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port):
|
|||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
# create and shard model
|
||||
model = model_fn().cuda()
|
||||
sharded_model = shardformer.shard_model(model)
|
||||
sharded_model = shardformer.optimize(model)
|
||||
|
||||
# add ddp
|
||||
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
|
||||
|
|