[shardformer] Refactor shardformer api (#4001)

* fix an error in readme

* simplify code

* refactor shardformer

* add todo

* remove slicer

* resolve code review
pull/4157/head
FoolPlayer 1 year ago committed by Frank Lee
parent 611971248c
commit d3bc530849

@@ -1 +1 @@
from .shard import ShardConfig, shard_model
from .shard import ShardConfig, ShardFormer

@@ -1,5 +1,7 @@
import torch.nn as nn
from .basepolicy import Policy
def build_policies():
r"""
@@ -41,47 +43,25 @@ def build_policies():
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers.models.llama.modeling_llama import LlamaModel
from .llama import LlamaPolicy
auto_policy_dict[LlamaModel] = LlamaPolicy
from transformers import LlamaForSequenceClassification
from .llama import LlamaForSequenceClassificationPolicy
auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
from transformers import LlamaForCausalLM
from .llama import LlamaForCausalLMPolicy
auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
from transformers import BertForMultipleChoice
from .bert import BertForMultipleChoicePolicy
auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy
from transformers import GPT2Model
from .gpt2 import GPT2Policy
auto_policy_dict[GPT2Model] = GPT2Policy
from transformers import GPT2LMHeadModel
from .gpt2 import GPT2LMHeadModelPolicy
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
t5 = {
T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
T5EncoderModel: T5EncoderModelPolicy,
T5Model: T5ModelPolicy,
}
auto_policy_dict.update(t5)
# from .llama import LlamaPolicy
# auto_policy_dict[LlamaModel] = LlamaPolicy
# from transformers import LlamaForSequenceClassification
# from .llama import LlamaForSequenceClassificationPolicy
# auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
# from transformers import LlamaForCausalLM
# from .llama import LlamaForCausalLMPolicy
# auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
# from transformers import GPT2Model
# from .gpt2 import GPT2Policy
# auto_policy_dict[GPT2Model] = GPT2Policy
# from transformers import GPT2LMHeadModel
# from .gpt2 import GPT2LMHeadModelPolicy
# auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
return auto_policy_dict
def get_autopolicy(model: nn.Module):
def get_autopolicy(model: nn.Module) -> Policy:
r"""
Return the auto policy for the model
@@ -97,7 +77,7 @@ def get_autopolicy(model: nn.Module):
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
)
return policy
return policy()
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
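For context, a minimal sketch of how the refactored `get_autopolicy` is meant to be consumed; note that it now returns a policy instance (`return policy()`) rather than the class itself. The top-level import path is an assumption inferred from the module paths shown in this diff.

```python
from transformers import BertForMaskedLM

# Assumed import path, inferred from the repository layout in this diff.
from colossalai.shardformer.policies.autopolicy import get_autopolicy

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
policy = get_autopolicy(model)    # e.g. a BertForMaskedLMPolicy() instance
policy.set_model(model)           # required before preprocess()/postprocess()
```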

@@ -1,102 +1,65 @@
# part of code modified from https://github.com/tunib-ai/parallelformers
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch.nn as nn
from ..shard.shard_config import ShardConfig
@dataclass
class Argument:
r"""
The argument class for the policy
Args:
attr_dict (Dict[str, Any]): The dict for the param setting
param_funcs (:class:`List[Callable]`): The list for the param functions
"""
attr_dict: Dict[str, Any]
param_funcs: List[Callable]
class ParallelModule():
@dataclass
class Layer:
r"""
The layer object for the policy
Args:
suffix (str): the suffix of the layer.
replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the weight in the layer is transposed; the weight in `torch.nn.Linear` is [out, in],
but in the GPT2 `Conv1D` layer it is [in, out], which is reversed.
n_cast (int): The number of chunks the weight will be cast into; for the fused q, k, v in an attention layer, n_cast should be 3. Commonly in TP, we just chunk the weight by the number of devices,
but in multi-head attention we need to chunk the weight by the number of devices * n_cast, so that
each device holds a part of the Q, K and V weights.
"""
suffix: str = None
replace_layer: Any = None
ignore: bool = False
reversed: bool = False
n_cast: int = None
def __init__(self):
pass
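To make the `reversed` flag above concrete, a toy sketch (GPT2's `Conv1D` is the canonical case; the deleted slicer.py at the end of this diff shows the actual handling):

```python
import torch

# torch.nn.Linear stores its weight as [out_features, in_features];
# transformers' Conv1D (used by GPT2) stores it as [in_features, out_features].
linear_weight = torch.empty(4, 8)    # [out, in]
conv1d_weight = linear_weight.t()    # [in, out] -- the "reversed" layout
assert conv1d_weight.shape == (8, 4)
# Column-sharding the Linear weight along dim 0 therefore corresponds to
# sharding the Conv1D weight along dim 1, which is why the slicer flips the
# dim and transposes the result back when reversed=True.
```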
@dataclass
class Col_Layer(Layer):
class SubModuleReplacementDescription:
r"""
Class for the column-sharded layer in tensor parallel
Describe how a submodule will be replaced
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
gather_output (bool): Whether to gather the output of the layer
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace the submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
"""
weight: str = None
bias: str = None
gather_output: bool = False
suffix: str
target_module: ParallelModule
kwargs: Dict[str, Any] = None
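As a hedged illustration of the new description (the `Linear1D_Col` target is borrowed from the old layer API imported in the bert policy below, and the `gather_output` kwarg is a hypothetical extra argument, not mandated by this diff):

```python
# Assumed import paths, inferred from the module layout in this diff.
from colossalai.shardformer.policies.basepolicy import SubModuleReplacementDescription
import colossalai.shardformer.layer.layers as col_nn

# Describe replacing `attention.self.query` with a column-parallel linear
# layer; kwargs would be forwarded to ParallelModule.from_native_module.
description = SubModuleReplacementDescription(
    suffix="attention.self.query",
    target_module=col_nn.Linear1D_Col,   # assumed target module
    kwargs={"gather_output": False},     # hypothetical extra argument
)
```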
@dataclass
class Row_Layer(Layer):
class ModulePolicyDescription:
r"""
Class for the row-sharded layer in tensor parallel
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
"""
weight: str = None
bias: str = None
Describe how the attributes and parameters will be transformed in a policy
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
must receive two arguments: module, process_group. One example is
@dataclass
class Dropout_Layer(Layer):
r"""
Class for the dropout layer in tensor parallel
Args:
p (str): The dropout rate suffix of the layer
"""
p: str = None
@dataclass
class Embedding_Layer(Layer):
r"""
Class for the embedding layer in tensor parallel
```python
def example_replace_weight(module: torch.nn.Module, process_group):
weight = module.weight
new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight)
```
Args:
weight (str): The weight suffix of the layer
sub_module_replacement: each element in the list is a SubModuleReplacementDescription object which specifies
the submodule to be replaced and the target module used for the replacement
"""
weight: str = None
gather_output: bool = True
attribute_replacement: Dict[str, Any]
param_replacement: List[Callable]
sub_module_replacement: List[SubModuleReplacementDescription]
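Putting the pieces together, a hedged sketch of one complete `ModulePolicyDescription` (the numbers mirror the BertLayer policy later in this diff; `ParallelModule` is the placeholder class defined above):

```python
# Assumed import path, inferred from the module layout in this diff.
from colossalai.shardformer.policies.basepolicy import (ModulePolicyDescription, ParallelModule,
                                                        SubModuleReplacementDescription)

tp_size = 2    # hypothetical tensor_parallel_size taken from a ShardConfig

# A description for a hypothetical BertLayer with hidden_size=768 and 12 heads.
bert_layer_description = ModulePolicyDescription(
    attribute_replacement={
        "attention.self.all_head_size": 768 // tp_size,
        "attention.self.num_attention_heads": 12 // tp_size,
    },
    param_replacement=[],    # no in-place parameter surgery in this sketch
    sub_module_replacement=[
        SubModuleReplacementDescription(
            suffix="attention.self.query",
            target_module=ParallelModule,    # placeholder, as in the bert policy below
        ),
    ],
)
```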
class Policy():
class Policy(ABC):
r"""
The base class for all the policies
Each model should have its own policy class, e.g. BertPolicy for the Bert model
or OPTPolicy for the OPT model.
AutoPolicy:
Shardformer has already defined some policies for huggingface models; just set ``custom_policy = None``
to use the auto policy. In shardformer's autopolicy, we define a base policy for one model type,
@@ -111,137 +74,75 @@ class Policy():
"""
@staticmethod
def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
def __init__(self) -> None:
self.model = None
def set_model(self, model: nn.Module) -> None:
r"""
Return the dict for the modification policy; the key is the original layer class and the value is the
argument for the modified layer
Set model as an attribute of the Policy object so that we can access the model's attributes.
Args:
model_config (:class:`transformers.Config`): The config of the transformer model
world_size (int): The world size of the sharding model
model (:class:`nn.Module`): The model to be set for this policy
"""
self.model = model
@abstractmethod
def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
"""
@abstractmethod
def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
Return the dict for the modification policy; the key is the original layer class and the value is the
description of how to modify that layer
Return:
Dict for the modification policy,
::
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
origin layer class1 (nn.Module): ModulePolicyDescription(
attribute_replacement = {
"attribute1": value1,
"attribute2": value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
param_replacement = [
function1,
function2,
...
]
),
origin layer class2 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
],
sub_module_replacement = [
`SubModuleReplacementDescription` description1,
`SubModuleReplacementDescription` description2,
...
]
),
origin layer class2 (nn.Module): ModulePolicyDescription(
...
),
...
}
"""
raise NotImplementedError
@staticmethod
def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
r"""
Return the pair of classes for model injection
Return:
The injected model; the first element is the original model class and the second is the new sharded model class
::
(OriginModel, CustomModel)
In `CustomModel`, we can override the forward and backward processes
"""
return None
@staticmethod
def binding_policy() -> Union[Dict[str, str], None]:
@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
r"""
Return the dict for the binding model; None means no need to bind
Return the new model class for the model; None means no need to modify the model class
Return:
This method should return the binding relationship for layers that share a weight or bias;
the key and the value are the suffixes of the shared weight or bias in the model
::
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return None
@staticmethod
def attn_in() -> Union[List, None]:
r"""
Attention qkv layer
Methods of this kind should return a list of ``Layer`` objects: ``Layer`` for no slicing,
``Col_Layer`` for column slicing, and ``Row_Layer`` for row slicing. For the parameters
of a ``Layer`` object, refer to the ``Layer`` class.
Returns:
List[Layer]: List of layer objects, each describing a new layer
"""
return None
@staticmethod
def attn_out() -> Union[List, None]:
r"""
Attention output projection layer
Returns:
List[Layer]: List of layer objects
"""
return None
@staticmethod
def mlp_in() -> Union[List, None]:
r"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer objects
"""
return None
@staticmethod
def mlp_out() -> Union[List, None]:
r"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer objects
"""
return None
@staticmethod
def embedding() -> Union[List, None]:
r"""
Partially slice the embedding layer
New model class
Return:
List[Layer]: List of layer objects
E.g.
```
return BertModel_
```
"""
return None
@staticmethod
def unembedding() -> Union[List, None]:
@abstractmethod
def postprocess(self) -> nn.Module:
r"""
Partially slice the unembedding layer; None means there is no unembedding layer
Return:
List[Layer]: List of layer objects
Perform some postprocessing of the model, like binding the weight of the embedding layer with
the classifier layer
"""
return None
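Under the refactored interface, a custom policy only has to implement the four abstract methods; a minimal no-op sketch (the pass-through behaviour is illustrative, not taken from this diff):

```python
import torch.nn as nn

# Assumed import paths, inferred from the module layout in this diff.
from colossalai.shardformer.policies.basepolicy import Policy
from colossalai.shardformer.shard.shard_config import ShardConfig


class NoOpPolicy(Policy):

    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
        return self.model    # nothing to reshape

    def module_policy(self, shard_config: ShardConfig = None):
        return {}            # no layer to modify

    def new_model_class(self):
        return None          # keep the original model class

    def postprocess(self) -> nn.Module:
        return self.model    # nothing to bind
```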

@@ -1,220 +1,77 @@
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer
from ..shard.shard_config import ShardConfig
from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ParallelModule():
def __init__(self):
pass
class BertPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
def preprocess(self, shard_config: ShardConfig = None):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self, shard_config: ShardConfig = None):
return {
BertLayer:
Argument(
attr_dict={
ModulePolicyDescription(
attribute_replacement={
# 1. shard hidden size
"attention.self.all_head_size": config.hidden_size // world_size,
"crossattention.self.all_head_size": config.hidden_size // world_size,
"attention.self.all_head_size":
self.model.config.hidden_size // shard_config.tensor_parallel_size,
"crossattention.self.all_head_size":
self.model.config.hidden_size // shard_config.tensor_parallel_size,
# 2. shard number of heads
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
"attention.self.num_attention_heads":
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
},
param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
BertEmbeddings:
Argument(
attr_dict={
# 1. shard vocab size
"word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
},
param_funcs=[
BertPolicy.embedding,
]),
}
@staticmethod
def attn_in():
return [
Col_Layer(
suffix="attention.self.query",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="attention.self.key",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="attention.self.value",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Dropout_Layer(
suffix="attention.self.dropout",
p="p",
replace_layer=col_nn.Dropout1D,
),
Col_Layer(
suffix="crossattention.self.query",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
suffix="crossattention.self.key",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
suffix="crossattention.self.value",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
]
@staticmethod
def attn_out():
return [
Row_Layer(
suffix="attention.output.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
),
Dropout_Layer(
suffix="attention.output.dropout",
p="p",
replace_layer=col_nn.Dropout1D,
),
Row_Layer(
suffix="crossattention.output.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
ignore=True,
),
]
@staticmethod
def mlp_in():
return [
Col_Layer(
suffix="intermediate.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
]
@staticmethod
def mlp_out():
return [
Row_Layer(
suffix="output.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
),
Dropout_Layer(
suffix="output.dropout",
p="p",
replace_layer=col_nn.Dropout1D,
)
]
@staticmethod
def embedding():
return [Col_Layer(
suffix="word_embeddings",
weight="weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
@staticmethod
def unembedding():
return [
Col_Layer(
suffix="decoder",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)
]
# BertModel
class BertModelPolicy(BertPolicy):
@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)
# BertForPretraining
class BertForPretrainingPolicy(BertPolicy):
@staticmethod
def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertPolicy.unembedding,
]),
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=ParallelModule,
),
])
}
argument.update(base_argument)
return argument
@staticmethod
def inject_policy():
def new_model_class(self):
# do nothing
return None
@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
# BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
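The loop above re-ties the decoder weight to the word embedding after sharding; a quick sanity check on the returned model (a sketch, assuming the standard HuggingFace BertForMaskedLM attribute layout):

```python
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
# ... shard the model; postprocess() then re-binds the tied weights ...
emb = model.bert.embeddings.word_embeddings.weight
dec = model.cls.predictions.decoder.weight
assert emb is dec, "embedding and decoder weights should be tied"
```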
class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument
@staticmethod
def inject_policy():
# return (BertForMaskedLM, BertForMaskedLM_)
return None
@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
def __init__(self) -> None:
super().__init__()
# BertLMHeadModel
@@ -231,36 +88,5 @@ class BertLMHeadModelPolicy(BertPolicy):
argument.update(base_argument)
return argument
@staticmethod
def inject_policy():
return None
@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)
def __init__(self) -> None:
super().__init__()

@@ -1,5 +1,5 @@
from .shard_config import ShardConfig
from .sharder import ModelSharder, shard_model
from .slicer import Slicer
from .sharder import ModelSharder
from .shardformer import ShardFormer
__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer']

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import List, Literal
__all__ = ['ShardConfig']
@@ -9,10 +10,18 @@ class ShardConfig:
The config for sharding the huggingface model
Args:
rank (int): The rank of local process
world_size (int): The world size of the distributed process
data_parallel_size (int): The size of data parallel
tensor_parallel_size (int): The size of tensor parallel
pipeline_parallel_size (int): The size of pipeline parallel
tensor_parallel_mode (Literal): The mode of tensor parallelism, chosen from `['1d', '2d', '2.5d', '3d']`
inference_only (bool): Whether to use inference-only mode; when set to `True`, the model
will not calculate the loss and will just return the output.
gather_output (bool): Whether to gather the output of the last layer of the model
"""
rank: int = None
world_size: int = None
data_parallel_size: int
tensor_parallel_size: int
pipeline_parallel_size: int
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
inference_only: bool = True
gather_output: bool = True
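For example, a construction matching the new fields (the three parallel sizes and the mode are now required, while `inference_only` and `gather_output` keep their defaults); this mirrors the ShardFormer example later in this diff:

```python
from colossalai.shardformer import ShardConfig

config = ShardConfig(
    data_parallel_size=1,
    tensor_parallel_size=2,
    pipeline_parallel_size=1,
    tensor_parallel_mode='1d',
)
assert config.inference_only and config.gather_output    # defaults
```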

@@ -4,11 +4,12 @@ import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from colossalai.cluster.process_group_manager import ProcessGroupManager
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
from ..utils.utils import getattr_, hasattr_, setattr_
from ..policies.basepolicy import Policy
from ..utils.utils import setattr_
from .shard_config import ShardConfig
from .slicer import Slicer
__all__ = ['ModelSharder', 'shard_model']
@@ -28,20 +29,23 @@ class ModelSharder(object):
model: nn.Module,
policy: Policy,
shard_config: ShardConfig = None, # TODO
) -> None:
pg_manager: ProcessGroupManager = None) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.slicer = Slicer(shard_config)
self.shard_config = shard_config
self.model_config = self.model.config
self.pg_manager = pg_manager
def shard(self) -> None:
self.reshape_embedding()
self.inject_model(self.model)
self.replace_layer(self.model)
self.bind_layer(self.model)
r"""
Shard the model according to the policy
"""
self.policy.set_model(self.model)
self.preprocess()
self.replace_model_class()
self.replace_module()
self.postprocess()
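Read top to bottom, `shard()` is now a fixed five-step pipeline. A hedged sketch of driving the sharder directly (normally one goes through `ShardFormer.shard_model`, added later in this diff; `pg_manager=None` matches the constructor default):

```python
from transformers import BertForMaskedLM

# Assumed import path, inferred from the shard/__init__.py exports above.
from colossalai.shardformer.shard import ModelSharder, ShardConfig

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig(data_parallel_size=1, tensor_parallel_size=2,
                           pipeline_parallel_size=1, tensor_parallel_mode='1d')
sharder = ModelSharder(model=model, policy=None,    # policy=None -> auto policy
                       shard_config=shard_config, pg_manager=None)
sharder.shard()    # set_model -> preprocess -> replace_model_class
                   # -> replace_module -> postprocess
```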
def reshape_embedding(self,) -> None:
def reshape_embedding(self) -> None:
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
@@ -52,10 +56,13 @@ class ModelSharder(object):
self.model.resize_token_embeddings(new_vocab_size)
self.model_config = self.model.config
def inject_model(
self,
model: nn.Module,
) -> None:
def preprocess(self) -> None:
self.model = self.policy.preprocess(self.shard_config)
def postprocess(self) -> None:
self.model = self.policy.postprocess()
def replace_model_class(self,) -> None:
r"""
Replace the model with the policy-defined model
Mainly modifies the forward and backward to fit the distributed model
@@ -64,49 +71,43 @@ class ModelSharder(object):
::
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
if inject_policy is None:
return
if inject_policy is None:
new_model_class = self.policy.new_model_class()
if new_model_class is None:
return
org_model_cls = inject_policy[0]
shard_model_cls = inject_policy[1]
if model.__class__ == org_model_cls:
for key in shard_model_cls.__dict__.keys():
if hasattr(model.__class__, key):
setattr(
model.__class__,
key,
getattr(shard_model_cls, key),
)
else:
raise NotImplementedError(f"{model.__class__} is not implemented so far")
for key in new_model_class.__dict__.keys():
if hasattr(self.model.__class__, key):
setattr(
self.model.__class__,
key,
getattr(new_model_class, key),
)
def replace_layer(
self,
model: nn.Module,
) -> None:
def replace_module(self,) -> None:
r"""
Replace the layer according to the policy, and replace the layer one by one
Replace the module according to the policy, and replace the module one by one
Args:
model (:class:`torch.nn.Module`): The layer to shard
model (:class:`torch.nn.Module`): The model to shard
"""
argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
for argument_policy in argument_policies.items():
origin_layer_cls = argument_policy[0]
attr_dict = argument_policy[1].attr_dict
param_funcs = argument_policy[1].param_funcs
self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
def traverse_replace_layer(
module_descriptions = self.policy.module_policy(self.shard_config)
print(f"*******{module_descriptions}")
for module_description in module_descriptions.items():
origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement
param_replacement = module_description[1].param_replacement
sub_module_replacement = module_description[1].sub_module_replacement
self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
sub_module_replacement)
def _recursive_replace_layer(
self,
layer: nn.Module,
module: nn.Module,
origin_cls: nn.Module,
attr_dict: Dict[str, Any],
param_funcs: List[Callable],
attr_replacement: Dict[str, Any],
param_replacement: List[Callable],
sub_module_replacement: List[Callable],
) -> None:
r"""
Traverse the model and replace the layer according to the policy
@@ -114,169 +115,69 @@ class ModelSharder(object):
Args:
layer (:class:`torch.nn.Module`): The object of layer to shard
origin_cls (:class:`transformers.model`): The origin layer class
attr_dict (Dict): The attribute dict to modify
policy_cls (:class:`Policy`): The policy class
attr_replacement (Dict): The attribute dict to modify
param_replacement (List[Callable]): The function list to get parameter shard information in policy
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
"""
if layer.__class__ == origin_cls:
for k, v in attr_dict.items():
setattr_(layer, k, v, ignore=True)
self.shard_one_layer(layer, param_funcs)
for name, child in layer.named_children():
self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
return layer
def shard_one_layer(
if module.__class__ == origin_cls:
self._replace_attr(module, attr_replacement)
self._replace_param(module, param_replacement)
self._replace_sub_module(module, sub_module_replacement)
for name, child in module.named_children():
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
sub_module_replacement)
def _replace_attr(
self,
org_layer: nn.Module,
param_funcs: List[Callable],
module: nn.Module,
attr_replacement: Dict[str, Any],
) -> None:
r"""
Shard one layer according to the policy; the layer should be of the same class as a key in the dict returned by the policy's argument_policy
Replace the attributes of the layer
Args:
org_layer (:class:`torch.nn.Module`): The origin layer object to shard
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
layer (:class:`torch.nn.Module`): The object of layer to shard
attr_replacement (Dict): The attribute dict to modify
"""
for func in param_funcs:
policy_layers = func()
for policy_layer in policy_layers:
suffix = policy_layer.suffix
replace_layer_cls = policy_layer.replace_layer
ignore = policy_layer.ignore
reversed = policy_layer.reversed
n_cast = policy_layer.n_cast
assert replace_layer_cls is not None, 'replace_layer should not be None'
# create new object to replace the origin layer
# Linear
suffix_layer = getattr_(org_layer, suffix, ignore=True)
assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
if suffix_layer is None and ignore:
continue
if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
weight = None
bias = None
weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
weight = getattr_(org_layer, weight_attr)
else:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
for k, v in attr_replacement.items():
setattr_(module, k, v, ignore=True)
if bias_attr is not None:
if hasattr_(org_layer, bias_attr):
bias = getattr_(org_layer, bias_attr)
else:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
# set the sliced weight and bias to the new nn_col layer
assert weight is not None or bias is not None
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1],
weight.shape[0],
bias=False if bias is None else True)
elif replace_layer_cls.__name__ == "Linear1D_Col":
gather_output = policy_layer.gather_output and self.shard_config.gather_output
replace_layer = replace_layer_cls(weight.shape[0],
weight.shape[1],
bias=False if bias is None else True,
gather_output=gather_output)
elif replace_layer_cls.__name__ == "Embedding1D":
gather_output = policy_layer.gather_output
replace_layer = replace_layer_cls(weight.shape[0],
weight.shape[1],
gather_output=gather_output)
elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
# setattr_(org_layer, suffix, replace_layer, ignore=ignore)
# self.set_param(replace_layer, weight, bias)
else:
raise NotImplementedError(
f"Replacing to {replace_layer_cls.__name__} is not implemented so far")
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
# dropout
elif isinstance(policy_layer, Dropout_Layer):
p_attr = suffix + '.' + policy_layer.p
p = getattr_(org_layer, p_attr, ignore=True)
replace_layer = replace_layer_cls(p)
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
else:
raise NotImplementedError(
f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far")
def set_param(self,
layer: Any,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
layer_attr: str = "") -> None:
def _replace_param(
self,
module: nn.Module,
param_replacement: List[Callable],
) -> None:
r"""
Reset the weight and bias of the layer object
Replace the parameter of the layer
Args:
layer (:class:`torch.nn.Module`): The layer object
layer_attr (str): The attribute name of the layer
weight (:class:`torch.Tensor`): The weight of the layer
bias (:class:`torch.Tensor`): The bias of the layer
layer (:class:`torch.nn.Module`): The object of layer to shard
param_replacement (List[Callable]): The function list to get parameter shard information in policy
"""
assert weight is not None or bias is not None
if weight is not None:
setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
self.set_layer_size(layer, layer_attr, weight.shape)
if bias is not None:
setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))
# TODO: support parameter shard
pass
def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
def _replace_sub_module(
self,
org_layer: nn.Module,
sub_module_replacement: List[Callable],
) -> None:
r"""
Set the layer attribute
Shard one layer according to the policy; the layer should be of the same class as a key in the dict returned by the policy's argument_policy
Args:
layer (:class:`torch.nn.Module`): The layer object
layer_attr (str): The attribute name of the layer
size (:class:`torch.Size`): The size of the tensor
"""
# Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
attrs = ["out_features", "in_features"]
for i, attr in enumerate(attrs):
if hasattr_(layer, f"{layer_attr}.{attr}"):
setattr_(layer, f"{layer_attr}.{attr}", size[i])
def bind_layer(self, model: nn.Module) -> None:
r"""
Bind the layer according to the binding policy
org_layer (:class:`torch.nn.Module`): The origin layer object to shard
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
Args:
model (:class:`torch.nn.Module`): The shard model
"""
binding_map = self.policy.binding_policy()
if binding_map is None:
return
for k, v in binding_map.items():
param = getattr_(model, k)
param = nn.Parameter(param)
setattr_(model, k, param)
setattr_(model, v, param)
for description in sub_module_replacement:
suffix = description.suffix
target_module = description.target_module
kwargs = description.kwargs
def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
r"""
The function is used to shard the PyTorch model.
assert target_module is not None, 'target_module should not be None'
Args:
model (`torch.nn.Model`): the origin huggingface model
shard_config (`ShardConfig`): the config for the distribution information
policy (`Policy`): the custom policy for sharding
"""
# TODO: init shard_config automatically
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
sharder.shard()
return model
# TODO: integrate with new layer
# replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
replace_layer = None
setattr_(org_layer, suffix, replace_layer)

@@ -0,0 +1,77 @@
import torch.nn as nn
from torch.utils.data import Dataset
from colossalai.cluster import DistCoordinator, ProcessGroupManager
from ..policies.basepolicy import Policy
from .shard_config import ShardConfig
from .sharder import ModelSharder
class ShardFormer:
"""
Parallelize model based on the given config and policy
Example:
```python
from colossalai.shardformer import ShardFormer, ShardConfig
from transformers import BertForMaskedLM
import colossalai
import torch
colossalai.launch_from_torch(config={})
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig(
tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d',
inference_only=True,
gather_output=True
)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
model = shard_former.shard_model(org_model)
```
"""
def __init__(self, shard_config: ShardConfig):
"""
Do two things:
1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
2. Serve as a store for the shard config
"""
self.coordinator = DistCoordinator()
self.shard_config = shard_config
self.pg_manager = None
def init_distributed(self) -> ProcessGroupManager:
"""
Initialize the distributed process group according to the shard config
"""
pg_manager = ProcessGroupManager()
if (self.shard_config.tensor_parallel_mode == '1d'):
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
self.pg_manager = pg_manager
return pg_manager
def shard_model(self, model: nn.Module, policy: Policy = None):
r"""
The function is used to shard the PyTorch model.
Args:
model (`torch.nn.Model`): the origin huggingface model
shard_config (`ShardConfig`): the config for the distribution information
policy (`Policy`): the custom policy for sharding
"""
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
sharder.shard()
return model
def shard_dataset(self, dataset: Dataset):
"""
Shard dataset for DP
"""
pass

@@ -1,163 +0,0 @@
import torch
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
from .shard_config import ShardConfig
dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}
class Slicer():
def __init__(
self,
shardconfig: ShardConfig #TODO
) -> None:
self.shardconfig = shardconfig
def slice_weight_bias(
self,
weight: torch.Tensor,
bias: torch.Tensor,
policy_layer_cls: Layer,
n_cast: int = None,
reversed: bool = False,
):
r"""
Slice the weight and bias according to policy layer cls
``Layer`` -> do nothing
``Col_Layer`` -> slice the weight and bias along dim 0
``Row_Layer`` -> slice the weight along dim 1 and do not slice the bias
Args:
weight (:class:`torch.Tensor`): The weight of the layer
bias (:class:`torch.Tensor`): The bias of the layer
policy_layer_cls (:class:`Layer`): The class representing how to slice the tensor
"""
if policy_layer_cls in [Layer, Dropout_Layer]:
return weight, bias
dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
# print(weight.shape, dim)
if policy_layer_cls == Col_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
bias = self.slice_tensor(bias, 0, True, n_cast)
elif policy_layer_cls == Row_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
elif policy_layer_cls == Embedding_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
else:
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
if reversed:
weight = weight.transpose(0, 1).contiguous()
return weight, bias
def slice_tensor(
self,
tensor_in: torch.Tensor,
dim: int,
is_bias: bool,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice tensor according to the config
Args:
tensor_in (:class:`torch.Tensor`): The tensor to slice
dim (int): The dimension to slice
is_bias (bool): Whether the tensor is bias
"""
if tensor_in is None:
return None
if not is_bias:
return self.slice_2d(tensor_in, dim, n_cast)
else:
return self.slice_1d(tensor_in, n_cast)
def slice_2d(
self,
tensor: torch.Tensor,
dim: int,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the 2D tensor
Args:
tensor (:class:`torch.Tensor`): The tensor to slice
dim (int): The dimension to slice
"""
assert dim in [0, 1], f"Only 2D tensors are supported; dim must be 0 or 1, but got dim={dim}"
if dim == 0:
return self.slice_row(tensor, n_cast)
elif dim == 1:
return self.slice_col(tensor, n_cast)
def slice_1d(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the 1D tensor
Args:
tensor (:class:`torch.Tensor`): The tensor to slice
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=0).contiguous()
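The `n_cast` branch above interleaves chunks so that every rank receives a piece of each fused projection (e.g. Q, K and V); a self-contained toy run of the same indexing:

```python
import torch

# world_size=2, n_cast=3: rank 0 takes chunks 0, 2, 4 and rank 1 takes
# chunks 1, 3, 5, so each rank ends up with part of Q, K and V.
tensor = torch.arange(12)    # stand-in for a fused QKV weight
world_size, n_cast, rank = 2, 3, 0
chunks = tensor.chunk(world_size * n_cast, dim=0)
mine = torch.cat([chunks[i] for i in range(rank, len(chunks), world_size)], dim=0)
print(mine)    # tensor([0, 1, 4, 5, 8, 9])
```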
def slice_col(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the tensor in column
Args:
tensor (:class:`torch.Tensor`): The tensor to slice
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=1).contiguous()
def slice_row(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the tensor in row
Args:
tensor (:class:`torch.Tensor`): The tensor to slice
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=0).contiguous()

@@ -0,0 +1 @@
from .utils import getattr_, hasattr_, setattr_
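These helpers resolve dotted attribute paths on a module tree; a hedged sketch of equivalent behaviour (the real implementations live in `colossalai/shardformer/utils/utils.py` and may differ, e.g. in how the `ignore` flag is handled):

```python
import functools

def getattr_(obj, attr: str, ignore: bool = False):
    """Resolve a dotted path such as 'bert.embeddings.word_embeddings.weight'."""
    try:
        return functools.reduce(getattr, attr.split('.'), obj)
    except AttributeError:
        if ignore:
            return None
        raise

def setattr_(obj, attr: str, value, ignore: bool = False):
    """Set the attribute named by the last path segment on its parent object."""
    parent_path, _, name = attr.rpartition('.')
    parent = getattr_(obj, parent_path, ignore=ignore) if parent_path else obj
    if parent is not None:
        setattr(parent, name, value)
```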