Browse Source

[shardformer]: Feature/shardformer, add some docstring and readme (#3816)

* init shardformer code structure

* add implement of sharder (inject and replace)

* add implement of replace layer to colossal layer

* separate different layer policy, add some notion

* implement 1d and 2d slicer, can tell col or row

* fix bug when slicing and inject model

* fix some bug; add inference test example

* add share weight and train example

* add train

* add docstring and readme

* add docstring for other files

* pre-commit
pull/4157/head
FoolPlayer 2 years ago committed by Frank Lee
parent
commit
8cc11235c0
  1. 2
      colossalai/nn/layer/parallel_1d/_operation.py
  2. 9
      colossalai/nn/layer/parallel_1d/layers.py
  3. 177
      colossalai/shardformer/README.md
  4. 12
      colossalai/shardformer/model/modeling_bert.py
  5. 19
      colossalai/shardformer/policies/autopolicy.py
  6. 108
      colossalai/shardformer/policies/basepolicy.py
  7. 113
      colossalai/shardformer/policies/bert.py
  8. 2
      colossalai/shardformer/shard/shardconfig.py
  9. 179
      colossalai/shardformer/shard/sharder.py
  10. 34
      colossalai/shardformer/shard/shardmodel.py
  11. 109
      colossalai/shardformer/shard/slicer.py
  12. 6
      colossalai/shardformer/test/config.py
  13. 87
      colossalai/shardformer/test/test.py
  14. 32
      colossalai/shardformer/utils/utils.py

2
colossalai/nn/layer/parallel_1d/_operation.py

@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
try:
@ -72,6 +73,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

9
colossalai/nn/layer/parallel_1d/layers.py

@ -469,7 +469,8 @@ class Linear1D_Col(ParallelLayer):
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
self.out_features_per_partition = out_features
# Parameters.
# Initialize weight.
@ -612,7 +613,8 @@ class Linear1D_Row(ParallelLayer):
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
self.input_size_per_partition = in_features
# Parameters.
# Initialize weight.
@ -884,7 +886,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings_per_partition = num_embeddings
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

177
colossalai/shardformer/README.md

@ -0,0 +1,177 @@
## ShardFormer
### Intro
Make the model in huggingface.co can be paralleled and can be used with colossalai according to custom policy.
### Quick start
1. Usage
- Use
``` python
from colossalai.shardformer.shard.shardmodel import ShardModel
from transformers import BertForMaskedLM
# create huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# make the huggingface model paralleled to ShardModel
# auto policy:
shardmodel = ShardModel(model).model
# custom policy:
from xxx import <POLICYCLASS>
shardmodel = ShardModel(model, <POLICYCLASS>).model
# do angthing as normal
...
```
- Policy
If you wanna parallel the model in custom way, just overwrite the policy class for the huggingface model.
You should do:
1. Inherit Policy class
2. Overwrite argument_policy method
- In this method you need to list which layers class you wanna modify and the attributes and parameters in those layers.
3. Overwrite inject_policy method [Optional]
- If you need to modify the forward or backward progress.
4. Overwrite or add the param recording functions
- These function use suffix to record the path of weight or bias for the layer.
5. Overwrite binding
More details can be found in shardformer/policies/basepolicy.py
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
CustomPolicy(Policy):
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
"""
Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions
Args:
model_config: The config of transformer model
shard_setting: The config of distributed model
Return:
Dict for the modify policy,
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
origin layer class2 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
...
}
"""
raise NotImplementedError
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
"""
Return the dict for the inject model
Return:
The injected model, key is the original model and value is the new shardmodel
"""
return ()
@staticmethod
def binding_policy() -> Dict:
"""
Return the dict for the binding model
"""
return NotImplementedError
@staticmethod
def attn_in() -> List:
"""
Attention qkv layer
Returns:
List[Layer]: List of layer object, each layer is the new
"""
return NotImplementedError
@staticmethod
def attn_out() -> List:
"""
Attention output projection layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def mlp_in() -> List:
"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def mlp_out() -> List:
"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def embedding() -> List:
"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def unembedding() -> List:
"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
```
2. Simple example
``` shell
# inference
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference
# train
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train
```

12
colossalai/shardformer/model/modeling_bert.py

@ -1,12 +1,14 @@
from typing import Any, Dict, List, Type
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import Any, Dict, List, Type
from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput
class BertForMaskedLM_(BertForMaskedLM):
def forward(
self,
input_ids=None,
@ -23,7 +25,7 @@ class BertForMaskedLM_(BertForMaskedLM):
return_dict=None,
**kwargs,
):
print("[Inject OK] Injected forward method")
# print("[Inject OK] Injected forward method")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
@ -48,7 +50,7 @@ class BertForMaskedLM_(BertForMaskedLM):
# if input_ids is not None:
# masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:

19
colossalai/shardformer/policies/autopolicy.py

@ -1,7 +1,8 @@
import torch.nn as nn
def build_policies():
"""
r"""
Build the policies for the model
Return:
@ -10,31 +11,37 @@ def build_policies():
auto_policy_dict = {}
from transformers.models.bert.modeling_bert import BertForMaskedLM
from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
return auto_policy_dict
def get_autopolicy(model:nn.Module):
"""
def get_autopolicy(model: nn.Module):
r"""
Return the auto policy for the model
Args:
model: The model to be used
model (:class:`nn.Module`): The model to get the auto policy
Return:
The auto policy for the model
:class:`Policy`: The auto policy for the model
"""
auto_policy_dict = build_policies()
policy = auto_policy_dict.get(model.__class__, None)
if policy is None:
raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}")
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
)
return policy
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)

108
colossalai/shardformer/policies/basepolicy.py

@ -1,28 +1,38 @@
# part of code modified from https://github.com/tunib-ai/parallelformers
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn as nn
import colossalai.nn as col_nn
from typing import Any, Dict, List, Type, Tuple, Callable
from transformers import AutoConfig
from dataclasses import dataclass, field
import colossalai.nn as col_nn
@dataclass
class Argument:
attr_dict : Dict[str, Any]
param_funcs : List[Callable]
binding_layers : List[nn.Module] = field(default_factory=list)
r"""
The argument class for the policy
Args:
attr_dict (Dict[str, Any]): The dict for the param setting
param_funcs (:class:`List[Callable]`): The list for the param functions
"""
attr_dict: Dict[str, Any]
param_funcs: List[Callable]
@dataclass
class Layer:
"""
r"""
The layer object for the policy
Args:
weight: The weight name of the layer
bias: The bias name of the layer
replace_layer: The layer to replace the original layer
ignore: Whether to ignore this layer if it is not in the model
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
"""
weight: str = None
bias: str = None
@ -32,45 +42,55 @@ class Layer:
@dataclass
class Col_Layer(Layer):
"""
r"""
Class for col shard layer in MegatronLM
Args:
gather_output (bool): Whether to gather the output of the layer
"""
gather_output: bool = False
@dataclass
class Row_Layer(Layer):
"""
r"""
Class for col shard layer in MegatronLM
"""
pass
class Policy():
"""
r"""
The base class for all the policies
For each different model, it should have a different policy class, like BertPolicy for Bert Model
or OPTPolicy for OPT model.
AutoPolicy:
shardformer already defined some policies for huggingface model, just set custom_policy = None
Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
BertForSequenceClassification, etc., for each different Bert model we difine different policy class
and overwrite the method inject_policy
and overwrite the method like ``inject_policy`` to modify the forward and backward process.
CustomPolicy:
If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
class for the example.
"""
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
"""
Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
r"""
Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer
Args:
model_config: The config of transformer model
shard_setting: The config of distributed model
model_config (:class:`tansformer.Config`): The config of transformer model
shard_config (:class:`ShardConfig`): The config for sharding model
Return:
Dict for the modify policy,
::
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
@ -102,32 +122,50 @@ class Policy():
"""
raise NotImplementedError
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
"""
r"""
Return the dict for the inject model
Return:
The injected model, key is the original model and value is the new shardmodel
::
(OrignModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return ()
@staticmethod
def binding_policy() -> Dict:
r"""
Return the dict for the binding model
Return:
This method should return the binding relationship for some layers share the weight or bias,
the key and value is the suffix of the weight or bias of the model
::
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return NotImplementedError
@staticmethod
def attn_in() -> List:
"""
r"""
Attention qkv layer
In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
in ``Layer`` object can refer to the ``Layer`` class.
Returns:
List[Layer]: List of layer object, each layer is the new
"""
return NotImplementedError
@staticmethod
def attn_out() -> List:
"""
r"""
Attention output projection layer
Returns:
@ -135,10 +173,9 @@ class Policy():
"""
return NotImplementedError
@staticmethod
def mlp_in() -> List:
"""
r"""
h -> 4h mlp layer
Returns:
@ -146,10 +183,9 @@ class Policy():
"""
return NotImplementedError
@staticmethod
def mlp_out() -> List:
"""
r"""
4h -> h mlp layer
Returns:
@ -157,24 +193,20 @@ class Policy():
"""
return NotImplementedError
@staticmethod
def embedding()->List:
"""
def embedding() -> List:
r"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def unembedding()->List:
"""
def unembedding() -> List:
r"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object

113
colossalai/shardformer/policies/bert.py

@ -1,56 +1,57 @@
from typing import Dict, List, Tuple, Type, Any, Callable
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
import colossalai.nn as col_nn
from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
from dataclasses import dataclass
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
class BertPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]:
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
return {
BertLayer: Argument(
attr_dict = {
# 1. shard hidden size
"attention.self.all_head_size": config.hidden_size // world_size,
"crossattention.self.all_head_size": config.hidden_size // world_size,
# 2. shard number of heads
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
},
param_funcs = [
BertPolicy.attn_in,
BertPolicy.attn_out,
BertPolicy.mlp_in,
BertPolicy.mlp_out
]
),
BertEmbeddings: Argument(
attr_dict = {
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
"word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size,
},
param_funcs = [
BertPolicy.embedding,
],
binding_layers = [
BertLMPredictionHead,
]
),
BertLMPredictionHead: Argument(
attr_dict = {
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
},
param_funcs = [
BertPolicy.unembedding,
]
)
BertLayer:
Argument(
attr_dict={
# 1. shard hidden size
"attention.self.all_head_size": config.hidden_size // world_size,
"crossattention.self.all_head_size": config.hidden_size // world_size,
# 2. shard number of heads
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
},
param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
BertEmbeddings:
Argument(
attr_dict={
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
"word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
},
param_funcs=[
BertPolicy.embedding,
]),
BertLMPredictionHead:
Argument(
attr_dict={
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
},
param_funcs=[
BertPolicy.unembedding,
])
}
@staticmethod
def binding_policy() -> Dict:
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
@staticmethod
@ -89,7 +90,6 @@ class BertPolicy(Policy):
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
]
@staticmethod
@ -111,7 +111,7 @@ class BertPolicy(Policy):
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(
Col_Layer(
weight="intermediate.dense.weight",
bias="intermediate.dense.bias",
replace_layer=col_nn.Linear1D_Col,
@ -130,12 +130,10 @@ class BertPolicy(Policy):
@staticmethod
def embedding() -> List:
return [
Col_Layer(
weight="word_embeddings.weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)
]
return [Col_Layer(
weight="word_embeddings.weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
@staticmethod
def unembedding() -> List:
@ -148,16 +146,21 @@ class BertPolicy(Policy):
)
]
from transformers import BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
return (BertForMaskedLM, BertForMaskedLM_)
class BertForSequenceClassificationPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Dict:
return {}

2
colossalai/shardformer/shard/shardconfig.py

@ -10,7 +10,7 @@ class ShardConfig:
fp16: bool = True
num_gpus: int = 2
world_size: int = 2
backend="nccl"
backend = "nccl"
verbose: str = 'simple'
seed: int = None
require_grad: bool = False

179
colossalai/shardformer/shard/sharder.py

@ -1,56 +1,59 @@
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable
from .shardconfig import ShardConfig
from dataclasses import dataclass
from ..policies.basepolicy import Policy, Layer
from ..policies.autopolicy import get_autopolicy
from .slicer import Slicer
from ..utils.utils import hasattr_, setattr_, getattr_
import colossalai.nn as col_nn
from colossalai.logging import get_dist_logger
import os
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Layer, Policy
from ..utils.utils import getattr_, hasattr_, setattr_
from .shardconfig import ShardConfig
from .slicer import Slicer
logger = get_dist_logger()
class ModelSharder(object):
"""
r"""
Shard the original huggingface model according to the policy
Args:
policy: The policy to shard the model
model: The model to shard
dist_setting: The setting of distributed model
policy (:class:`Policy`): The policy to shard the model
model (:class:`torch.Module`): The model to shard
shard_config: The setting of distributed model
"""
def __init__(
self,
model: nn.Module,
policy: Policy,
shard_config: ShardConfig = None, # TODO
) -> None:
shard_config: ShardConfig = None, # TODO
) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.slicer = Slicer(shard_config)
self.shard_config = shard_config
self.model_config = self.model.config
self.binding_map = {}
def shard(self) -> None:
self.inject_model(self.model)
self.replace_layer(self.model)
self.bind_layer(self.model)
def inject_model(
self,
model: nn.Module,
) -> None:
"""
self,
model: nn.Module,
) -> None:
r"""
Replace the model to policy defined model
Mainly modify the forward and backward to fit distributed model
e.g.
::
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
@ -64,49 +67,43 @@ class ModelSharder(object):
setattr(
model.__class__,
key,
getattr(shard_model_cls,key),
getattr(shard_model_cls, key),
)
else:
raise NotImplementedError(f"{model.__class__} is not implemented so far")
def replace_layer(
self,
model: nn.Module,
) -> None:
"""
self,
model: nn.Module,
) -> None:
r"""
Replace the layer according to the policy, and replace the layer one by one
Args:
layer: The layer to shard
model (:class:`torch.nn.Module`): The layer to shard
"""
argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
for argument_policy in argument_policies.items():
origin_layer_cls = argument_policy[0]
attr_dict = argument_policy[1].attr_dict
param_funcs = argument_policy[1].param_funcs
binding_layers = argument_policy[1].binding_layers
# if binding_layer is not None:
# self.binding_map[origin_layer_cls] = binding_layer
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers)
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
def reverse_replace_layer(
self,
layer: nn.Module,
origin_cls: nn.Module,
attr_dict: Dict[str, Any],
param_funcs: List[Callable],
binding_layers: List[nn.Module]
) -> None:
"""
self,
layer: nn.Module,
origin_cls: nn.Module,
attr_dict: Dict[str, Any],
param_funcs: List[Callable],
) -> None:
r"""
Reverse the replace layer operation
Args:
layer: The object of layer to shard
origin_cls: The origin layer class
attr_dict: The attribute dict to modify
policy_cls: The policy class
layer (:class:`torch.nn.Module`): The object of layer to shard
origin_cls (:class:`transformers.model`): The origin layer class
attr_dict (Dict): The attribute dict to modify
policy_cls (:class:`Policy`): The policy class
"""
for name, child in layer.named_children():
if child.__class__ == origin_cls:
@ -115,25 +112,23 @@ class ModelSharder(object):
setattr_(child, k, v, ignore=True)
# print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
# setattr_(layer, name, self.shard_one_layer(child, policy_cls))
self.shard_one_layer(child, param_funcs, binding_layers)
self.shard_one_layer(child, param_funcs)
continue
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers)
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
return layer
def shard_one_layer(
self,
org_layer: nn.Module,
param_funcs: List[Callable],
binding_layers: List[nn.Module]
) -> None:
"""
self,
org_layer: nn.Module,
param_funcs: List[Callable],
) -> None:
r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
Args:
org_layer: The origin layer object to shard
param_funcs: The function list to get shard information in policy class
org_layer (:class:`torch.nn.Module`): The origin layer object to shard
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
"""
# print(org_layer)
@ -148,7 +143,7 @@ class ModelSharder(object):
ignore = policy_layer.ignore
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
print(gather_output)
# print(gather_output)
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
@ -172,67 +167,81 @@ class ModelSharder(object):
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
# save the binding information
for binding_layer in binding_layers:
self.binding_map[binding_layer] = dict(weight=weight, bias=bias)
# print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
# create new object to replace the origin layer
if replace_layer_cls is not None:
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True)
replace_layer = replace_layer_cls(weight.shape[1],
weight.shape[0],
bias=False if bias is None else True)
elif replace_layer_cls.__name__ == "Linear1D_Col":
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output)
replace_layer = replace_layer_cls(weight.shape[0],
weight.shape[1],
bias=False if bias is None else True,
gather_output=gather_output)
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
else:
raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
raise NotImplementedError(
f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
# do not replace the layer object, just replace the weight and bias
else:
self.set_param(org_layer, layer_attr, weight, bias)
def set_param(
self,
layer: Any,
layer_attr: str = "",
weight: torch.Tensor = None,
bias: torch.Tensor = None
) -> None:
"""
def set_param(self,
layer: Any,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
layer_attr: str = "") -> None:
r"""
Reset the weight and bias of the layer object
Args:
layer: The layer object
layer_attr: The attribute name of the layer
weight: The weight of the layer
bias: The bias of the layer
layer (:class:`torch.nn.Module`): The layer object
layer_attr (str): The attribute name of the layer
weight (:class:`torch.Tensor`): The weight of the layer
bias (:class:`torch.Tensor`): The bias of the layer
"""
assert weight is not None or bias is not None
if weight is not None:
setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight))
setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
self.set_layer_size(layer, layer_attr, weight.shape)
if bias is not None:
setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias))
setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))
def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
"""
r"""
Set the layer attribute
Args:
layer: The layer object
layer_attr: The attribute name of the layer
size: Torch.size
layer (:class:`torch.nn.Module`): The layer object
layer_attr (str): The attribute name of the layer
size (:class:`torch.Size`): The size of the tensor
"""
# Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
attrs = ["out_features", "in_features"]
for i, attr in enumerate(attrs):
if hasattr_(layer, f"{layer_attr}.{attr}"):
setattr_(layer, f"{layer_attr}.{attr}", size[i])
def bind_layer(self, model: nn.Module) -> None:
r"""
Bind the layer according to the binding policy
Args:
model (:class:`torch.nn.Module`): The shard model
"""
binding_map = self.policy.binding_policy()
for k, v in binding_map.items():
param = getattr_(model, k)
param = nn.Parameter(param)
setattr_(model, k, param)
setattr_(model, v, param)

34
colossalai/shardformer/shard/shardmodel.py

@ -1,46 +1,48 @@
import os
from contextlib import suppress
from dataclasses import dataclass
import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
import torch.distributed as dist
from dataclasses import dataclass
from contextlib import suppress
from colossalai.tensor.d_tensor.layout import Layout
from ..policies.basepolicy import Policy
from .sharder import ModelSharder
from .shardconfig import ShardConfig
from .sharder import ModelSharder
class ShardModel(object):
"""
The class for sharding the huggingface model, self.model is the sharded model
r"""
The class for sharding the huggingface model, ''self.model'' is the sharded model
Just creat a new ShardModel object to shard huggingface model
Args:
model: the origin huggingface model
dist_config: the config for distribute information
custom_policy: the custom policy for sharding
model (:class:`torch.nn.Model`): the origin huggingface model
dist_config (:class:`ShardConfig`): the config for distribute information
custom_policy (:class:`Policy`): the custom policy for sharding
"""
def __init__(
self,
model: nn.Module,
shard_config: ShardConfig = None, # TODO
custom_policy: Policy = None,
) -> None:
self,
model: nn.Module,
shard_config: ShardConfig = None, # TODO
custom_policy: Policy = None,
) -> None:
self.model = model
self.shard_config = shard_config
self.policy = custom_policy
# self.layout=, # TODO
sharder=ModelSharder(
sharder = ModelSharder(
model=self.model,
policy=self.policy,
shard_config=self.shard_config,
)
sharder.shard()
def set_environ(self) -> None:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"

109
colossalai/shardformer/shard/slicer.py

@ -1,40 +1,40 @@
import os
from typing import Dict, Tuple
from dataclasses import dataclass
from typing import Dict, Tuple
import torch
import torch.distributed as dist
from ..policies.basepolicy import Layer, Col_Layer, Row_Layer
from .shardconfig import ShardConfig
from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
from .shardconfig import ShardConfig
dim_mapping = {Col_Layer: 1, Row_Layer: 0}
class Slicer():
def __init__(
self,
shardconfig: ShardConfig #TODO
self,
shardconfig: ShardConfig #TODO
) -> None:
self.shardconfig = shardconfig
def slice_weight_bias(
self,
weight: torch.Tensor,
bias: torch.Tensor,
policy_layer_cls: Layer,
):
"""
r"""
Slice the weight and bias according to policy layer cls
Layer -> do nothing
Col_Layer -> slice the weight and bias along dim 1
Row_Layer -> slice the weight along dim 0 and do not slice bias
``Layer`` -> do nothing
``Col_Layer`` -> slice the weight and bias along dim 1
``Row_Layer`` -> slice the weight along dim 0 and do not slice bias
Args:
weight: The weight of the layer
bias: The bias of the layer
policy_layer_class: The class represent how to slice the tensor
weight (:class:`torch.nn.Module`): The weight of the layer
bias: (:class:`torch.nn.Module`): The bias of the layer
policy_layer_class (:class:`Policy`): The class represent how to slice the tensor
"""
if policy_layer_cls == Layer:
return weight, bias
@ -47,50 +47,19 @@ class Slicer():
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
return weight, bias
def slice_weight(
self,
weight: torch.Tensor,
policy_layer_cls: Layer,
) -> torch.Tensor:
"""
Slice the weight and bias according to the shardconfig
Args:
weight: The weight of the layer
bias: The bias of the layer
policy_layer_class: The class represent how to slice the tensor
"""
if weight is not None:
dim = dim_mapping[policy_layer_cls]
weight = self.slice_tensor(weight, dim, False)
return weight
def slice_bias(
self,
bias: torch.Tensor,
) -> torch.Tensor:
"""
Slice the bias according to the shardconfig
Args:
bias: The bias of the layer
"""
assert bias is not None, "The bias is None"
if bias is not None:
bias = self.slice_tensor(bias, 1, True)
return bias
def slice_tensor(
self,
tensor_in: torch.Tensor,
dim: int,
is_bias: bool,
) -> torch.Tensor:
"""
r"""
Slice tensor according to the config
Args:
tensor_in (:class:`torch.Tensor`): The tensor to slice
dim (int): The dimension to slice
is_bias (bool): Whether the tensor is bias
"""
if tensor_in is None:
return None
@ -99,69 +68,75 @@ class Slicer():
else:
return self.slice_1d(tensor_in)
def slice_2d(
self,
tensor: torch.Tensor,
dim: int,
) -> torch.Tensor:
"""
r"""
Slice the 2D tensor
Args:
tensor: The tensor to slice
tensor (:class:`torch.Tensor`): The tensor to slice
dim (int): The dimension to slice
"""
assert dim in [0,1], f"Only support 2D tensor, but got {dim}D tensor"
assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
if dim == 0:
return self.slice_row(tensor)
elif dim == 1:
return self.slice_col(tensor)
def slice_1d(
self,
tensor: torch.Tensor,
dim: int = None,
) -> torch.Tensor:
"""
r"""
Slice the 1D tensor
Args:
tensor: The tensor to slice
tensor (:class:`torch.Tensor`): The tensor to slice
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[down_idx:up_idx]
return tensor[down_idx:up_idx].contiguous()
def slice_col(
self,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
r"""
Slice the tensor in column
Args:
tensor: The tensor to slice
tensor (:class:`torch.Tensor`): The tensor to slice
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[down_idx:up_idx,:]
return tensor[down_idx:up_idx, :].contiguous()
def slice_row(
self,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
r"""
Slice the tensor in column
Args:
tensor: The tensor to slice
tensor (:class:`torch.Tensor`): The tensor to slice
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[:,down_idx:up_idx]
return tensor[:, down_idx:up_idx].contiguous()

6
colossalai/shardformer/test/config.py

@ -1,5 +1 @@
parallel = dict(
data=1,
pipeline=1,
tensor=dict(size=2, mode='1d')
)
parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))

87
colossalai/shardformer/test/test.py

@ -1,23 +1,51 @@
from transformers import AutoTokenizer
from transformers import BertForMaskedLM
import argparse
import inspect
import os
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
import colossalai
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.utils import get_current_device, print_rank_0
from colossalai.logging import get_dist_logger
from colossalai.shardformer.shard.shardconfig import ShardConfig
import inspect
import argparse
import torch.nn as nn
import os
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.utils import get_current_device, print_rank_0
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--mode", type=str, default='inference')
return parser.parse_args()
def load_data():
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
# datasets=load_dataset("yelp_review_full")
tokenized_datasets = datasets.map(
lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
# tokenized_datasets=tokenized_datasets.rename_column("label","labels")
tokenized_datasets.set_format("torch")
train_dataset = tokenized_datasets["train"].select(range(500))
test_dataset = tokenized_datasets["test"].select(range(100))
datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
return train_dataloader, eval_dataloader
def inference(model: nn.Module):
# print(model)
print(model)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
token = "Hello, my dog is cute"
inputs = tokenizer(token, return_tensors="pt")
inputs.to("cuda")
@ -25,13 +53,48 @@ def inference(model: nn.Module):
outputs = model(**inputs)
print(outputs)
def train(model: nn.Module, num_epoch: int = 2):
train_dataloader, eval_dataloader = load_data()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
progress_bar = tqdm(range((num_epoch) * len(train_dataloader)))
criterion = nn.CrossEntropyLoss()
model.to("cuda")
model.train()
for epoch in range(num_epoch):
progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
for batch in train_dataloader:
optimizer.zero_grad()
batch = {k: v.to('cuda') for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
progress_bar.update(1)
train_loss = loss
loss = 0.0
for batch in eval_dataloader:
batch = {k: v.to('cuda') for k, v in batch.items()}
outputs = model(**batch)
# loss = outputs.loss
loss += outputs.loss.item()
# loss = criterion(outputs.logits, batch["input_ids"])
test_loss = loss / len(eval_dataloader)
print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
if __name__ == "__main__":
args = get_args()
colossalai.launch_from_torch(config=args.config)
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
shard_config = ShardConfig(
rank = int(str(get_current_device()).split(':')[-1]),
world_size= int(os.environ['WORLD_SIZE']),
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),
)
shardmodel = ShardModel(model, shard_config)
inference(shardmodel.model)
if args.mode == "train":
train(shardmodel.model)
elif args.mode == "inference":
inference(shardmodel.model)

32
colossalai/shardformer/utils/utils.py

@ -1,10 +1,10 @@
def hasattr_(obj, attr: str):
"""
r"""
Check whether the object has the multi sublevel attr
Args:
obj: The object to check
attr: The multi level attr to check
obj (object): The object to check
attr (str): The multi level attr to check
"""
attrs = attr.split('.')
for a in attrs:
@ -14,15 +14,16 @@ def hasattr_(obj, attr: str):
return False
return True
def setattr_(obj, attr: str, value, ignore: bool=False):
"""
def setattr_(obj, attr: str, value, ignore: bool = False):
r"""
Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist
Args:
obj: The object to set
attr: The multi level attr to set
value: The value to set
ignore: Whether to ignore when the attr doesn't exist
obj (object): The object to set
attr (str): The multi level attr to set
value (Any): The value to set
ignore (bool): Whether to ignore when the attr doesn't exist
"""
attrs = attr.split('.')
@ -31,18 +32,19 @@ def setattr_(obj, attr: str, value, ignore: bool=False):
obj = getattr(obj, a)
except AttributeError:
if ignore:
return
return
raise AttributeError(f"Object {obj} has no attribute {attr}")
setattr(obj, attrs[-1], value)
def getattr_(obj, attr: str, ignore: bool=None):
"""
def getattr_(obj, attr: str, ignore: bool = None):
r"""
Get the object's multi sublevel attr
Args:
obj: The object to set
attr: The multi level attr to set
ignore: Whether to ignore when the attr doesn't exist
obj (object): The object to set
attr (str): The multi level attr to set
ignore (bool): Whether to ignore when the attr doesn't exist
"""
attrs = attr.split('.')

Loading…
Cancel
Save