mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Refactor shardformer api (#4001)
* fix an error in readme
* simplify code
* refactor shardformer
* add todo
* remove slicer
* resolve code review

pull/4157/head
parent 611971248c
commit d3bc530849

@@ -1 +1 @@
from .shard import ShardConfig, shard_model
|
||||
from .shard import ShardConfig, ShardFormer
|
||||
|
|
|
@@ -1,5 +1,7 @@
|
|||
import torch.nn as nn
|
||||
|
||||
from .basepolicy import Policy
|
||||
|
||||
|
||||
def build_policies():
|
||||
r"""
|
||||
|
@@ -41,47 +43,25 @@ def build_policies():
|
|||
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
|
||||
from .llama import LlamaPolicy
|
||||
auto_policy_dict[LlamaModel] = LlamaPolicy
|
||||
|
||||
from transformers import LlamaForSequenceClassification
|
||||
|
||||
from .llama import LlamaForSequenceClassificationPolicy
|
||||
auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
|
||||
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
from .llama import LlamaForCausalLMPolicy
|
||||
auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
|
||||
|
||||
from transformers import BertForMultipleChoice
|
||||
|
||||
from .bert import BertForMultipleChoicePolicy
|
||||
auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy
|
||||
|
||||
from transformers import GPT2Model
|
||||
|
||||
from .gpt2 import GPT2Policy
|
||||
auto_policy_dict[GPT2Model] = GPT2Policy
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
from .gpt2 import GPT2LMHeadModelPolicy
|
||||
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
|
||||
|
||||
from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
|
||||
from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
|
||||
t5 = {
|
||||
T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
|
||||
T5EncoderModel: T5EncoderModelPolicy,
|
||||
T5Model: T5ModelPolicy,
|
||||
}
|
||||
auto_policy_dict.update(t5)
|
||||
# from .llama import LlamaPolicy
|
||||
# auto_policy_dict[LlamaModel] = LlamaPolicy
|
||||
# from transformers import LlamaForSequenceClassification
|
||||
# from .llama import LlamaForSequenceClassificationPolicy
|
||||
# auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
|
||||
# from transformers import LlamaForCausalLM
|
||||
# from .llama import LlamaForCausalLMPolicy
|
||||
# auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
|
||||
# from transformers import GPT2Model
|
||||
# from .gpt2 import GPT2Policy
|
||||
# auto_policy_dict[GPT2Model] = GPT2Policy
|
||||
# from transformers import GPT2LMHeadModel
|
||||
# from .gpt2 import GPT2LMHeadModelPolicy
|
||||
# auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
|
||||
|
||||
return auto_policy_dict
|
||||
|
||||
|
||||
def get_autopolicy(model: nn.Module):
|
||||
def get_autopolicy(model: nn.Module) -> Policy:
|
||||
r"""
|
||||
Return the auto policy for the model
|
||||
|
||||
|
@@ -97,7 +77,7 @@ def get_autopolicy(model: nn.Module):
|
|||
raise NotImplementedError(
|
||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
|
||||
)
|
||||
return policy
|
||||
return policy()
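A minimal usage sketch (not part of the commit; it assumes `transformers` is installed and relies on the BertForSequenceClassification entry registered above):

```python
# Hypothetical usage sketch: get_autopolicy returns a policy *instance*
# (note the `return policy()` above), so its methods can be called directly.
from transformers import BertConfig, BertForSequenceClassification

model = BertForSequenceClassification(BertConfig())
policy = get_autopolicy(model)    # -> BertForSequenceClassificationPolicy()
policy.set_model(model)           # attach the model before preprocess()/module_policy()
```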
|
||||
|
||||
|
||||
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
|
||||
|
|
|
@@ -1,102 +1,65 @@
|
|||
# part of code modified from https://github.com/tunib-ai/parallelformers
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..shard.shard_config import ShardConfig
|
||||
|
||||
@dataclass
|
||||
class Argument:
|
||||
r"""
|
||||
The argument class for the policy
|
||||
|
||||
Args:
|
||||
attr_dict (Dict[str, Any]): The dict for the param setting
|
||||
param_funcs (:class:`List[Callable]`): The list for the param functions
|
||||
"""
|
||||
attr_dict: Dict[str, Any]
|
||||
param_funcs: List[Callable]
|
||||
class ParallelModule():
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Layer:
|
||||
class SubModuleReplacementDescription:
|
||||
r"""
|
||||
The layer object for the policy
|
||||
Describe how a submodule will be replaced
|
||||
|
||||
Args:
|
||||
suffix (str): the suffix of the layer.
|
||||
replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
|
||||
ignore (bool): Whether to ignore this layer if it is not in the model
|
||||
reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
|
||||
but in GPT2 `Conv1D` layer is [in, out] which is reversed.
|
||||
n_cast (int): The number of chunks the weight will be cast to; for q, k, v in an attention layer, n_cast should be 3. Commonly in TP we just chunk the weight by the number of devices,
|
||||
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
|
||||
each device should have a part of Q, K and V weight.
|
||||
suffix (str): used to get the submodule object
|
||||
target_module (ParallelModule): specifies the module class used to replace the submodule
|
||||
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
|
||||
"""
|
||||
suffix: str = None
|
||||
replace_layer: Any = None
|
||||
ignore: bool = False
|
||||
reversed: bool = False
|
||||
n_cast: int = None
|
||||
suffix: str
|
||||
target_module: ParallelModule
|
||||
kwargs: Dict[str, Any] = None
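As a standalone illustration of the interleaved chunking described for ``n_cast`` above (a sketch with made-up sizes, not part of the diff):

```python
# A fused q/k/v weight is split into world_size * n_cast chunks along dim 0 and
# each rank keeps every world_size-th chunk, so every device gets a part of Q, K and V.
import torch

world_size, rank, n_cast = 2, 0, 3            # hypothetical 2-way tensor parallelism
weight = torch.arange(12.0).reshape(12, 1)    # stand-in for a [3 * hidden, hidden] weight
chunks = weight.chunk(world_size * n_cast, dim=0)
local = torch.cat(chunks[rank::world_size], dim=0)   # rank 0 keeps chunks 0, 2, 4
print(local.shape)                            # torch.Size([6, 1])
```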
|
||||
|
||||
|
||||
@dataclass
|
||||
class Col_Layer(Layer):
|
||||
class ModulePolicyDescription:
|
||||
r"""
|
||||
Class for col shard layer in tensor parallel
|
||||
Describe how the attributes and parameters will be transformed in a policy
|
||||
|
||||
Args:
|
||||
weight (str): The weight suffix of the layer
|
||||
bias (str): The bias suffix of the layer
|
||||
gather_output (bool): Whether to gather the output of the layer
|
||||
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
|
||||
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
|
||||
must receive two arguments: module, process_group. One example is
|
||||
|
||||
```python
|
||||
def example_replace_weight(module: torch.nn.Module, process_group):
|
||||
weight = module.weight
|
||||
new_weight = shard_rowwise(weight, process_group)
|
||||
module.weight = torch.nn.Parameter(new_weight)
|
||||
```
|
||||
|
||||
sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies
|
||||
the module to be replaced and the target module used to replacement
|
||||
"""
|
||||
weight: str = None
|
||||
bias: str = None
|
||||
gather_output: bool = False
|
||||
attribute_replacement: Dict[str, Any]
|
||||
param_replacement: List[Callable]
|
||||
sub_module_replacement: List[SubModuleReplacementDescription]
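For illustration, a minimal policy entry combining the two dataclasses might look like the following sketch (it mirrors the BertLayer entry used by BertPolicy later in this diff; `ParallelModule` is the placeholder class defined above):

```python
# Hypothetical policy entry, shown only to illustrate how the dataclasses compose.
from transformers.models.bert.modeling_bert import BertLayer

example_policy = {
    BertLayer:
        ModulePolicyDescription(
            attribute_replacement={"attention.self.num_attention_heads": 12 // 2},
            param_replacement=[],
            sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="attention.self.query",
                    target_module=ParallelModule,    # placeholder parallel layer class
                ),
            ],
        ),
}
```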
|
||||
|
||||
|
||||
@dataclass
|
||||
class Row_Layer(Layer):
|
||||
r"""
|
||||
Class for row shard layer in tensor parallel
|
||||
|
||||
Args:
|
||||
weight (str): The weight suffix of the layer
|
||||
bias (str): The bias suffix of the layer
|
||||
"""
|
||||
weight: str = None
|
||||
bias: str = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dropout_Layer(Layer):
|
||||
r"""
|
||||
Class for dropout layer in tensor parallel
|
||||
|
||||
Args:
|
||||
p (str): The dropout rate suffix of the layer
|
||||
"""
|
||||
p: str = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Embedding_Layer(Layer):
|
||||
r"""
|
||||
Class for embedding layer in tensor parallel
|
||||
|
||||
Args:
|
||||
weight (str): The weight suffix of the layer
|
||||
"""
|
||||
weight: str = None
|
||||
gather_output: bool = True
|
||||
|
||||
|
||||
class Policy():
|
||||
class Policy(ABC):
|
||||
r"""
|
||||
The base class for all the policies
|
||||
|
||||
For each different model, it should have a different policy class, like BertPolicy for Bert Model
|
||||
or OPTPolicy for OPT model.
|
||||
|
||||
AutoPolicy:
|
||||
Shardformer already defines some policies for Hugging Face models; just set ``custom_policy`` = None
|
||||
to use the auto policy. In the shardformer auto policy, we define a base policy for each model type,
|
||||
|
@@ -111,137 +74,75 @@ class Policy():
|
|||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
|
||||
def set_model(self, model: nn.Module) -> None:
|
||||
r"""
|
||||
Set model as an attribute of the Policy object so that we can access the model's attributes.
|
||||
|
||||
Args:
|
||||
model (:class:`nn.Module`): The model to be sharded
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
|
||||
r"""
|
||||
Perform some preprocessing of the model, like reshaping the embedding layer
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
r"""
|
||||
Return the dict of the module policy, where the key is the original layer class and the value is the
|
||||
description of how to modify the layer
|
||||
|
||||
Args:
|
||||
model_config (:class:`transformers.Config`): The config of the transformer model
|
||||
world_size (int): The world size of the sharding model
|
||||
|
||||
Return:
|
||||
Dict for the modify policy,
|
||||
::
|
||||
{
|
||||
origin layer class1 (nn.Module): Argument(
|
||||
attr_dict = {
|
||||
argument1: value1,
|
||||
argument2: value2,
|
||||
origin layer class1 (nn.Module): ModulePolicyDescription(
|
||||
attribute_replacement = {
|
||||
"attribute1": value1,
|
||||
"attribute2": value2,
|
||||
...
|
||||
},
|
||||
param_funcs = [
|
||||
staticmethod1,
|
||||
staticmethod2,
|
||||
param_replacement = [
|
||||
function1,
|
||||
function2,
|
||||
...
|
||||
],
|
||||
sub_module_replacement = [
|
||||
`SubModuleReplacementDescription` description1,
|
||||
`SubModuleReplacementDescription` description2,
|
||||
...
|
||||
]
|
||||
),
|
||||
origin layer class2 (nn.Module): Argument(
|
||||
attr_dict = {
|
||||
argument1: value1,
|
||||
argument2: value2,
|
||||
...
|
||||
},
|
||||
param_funcs = [
|
||||
staticmethod1,
|
||||
staticmethod2,
|
||||
...
|
||||
]
|
||||
origin layer class2 (nn.Module): ModulePolicyDescription(
|
||||
...
|
||||
),
|
||||
...
|
||||
}
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
|
||||
@abstractmethod
|
||||
def new_model_class(self) -> Union[Type[nn.Module], None]:
|
||||
r"""
|
||||
Return the dict for the inject model
|
||||
Return the new model class for the new model, None means no need to modify the model class
|
||||
|
||||
Return:
|
||||
The injected model, where the key is the original model and the value is the new shard model
|
||||
::
|
||||
(OriginModel, CustomModel)
|
||||
in `CustomModel`, we can overwrite the forward and backward process
|
||||
"""
|
||||
return None
|
||||
New model class
|
||||
|
||||
@staticmethod
|
||||
def binding_policy() -> Union[Dict[str, str], None]:
|
||||
E.g.
|
||||
```
|
||||
return BertModel_
|
||||
```
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self) -> nn.Module:
|
||||
r"""
|
||||
Return the dict for the binding model, None means no need to bind
|
||||
|
||||
Return:
|
||||
This method should return the binding relationship for layers that share a weight or bias;
|
||||
the key and value are the suffixes of the weight or bias in the model
|
||||
::
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
Perform some postprocessing of the model, like binding the weight of embedding layer with
|
||||
the classifier layer
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def attn_in() -> Union[List, None]:
|
||||
r"""
|
||||
Attention qkv layer
|
||||
In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
|
||||
``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
|
||||
in ``Layer`` object can refer to the ``Layer`` class.
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer objects, each describing the new layer that replaces the original one
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def attn_out() -> Union[List, None]:
|
||||
r"""
|
||||
Attention output projection layer
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def mlp_in() -> Union[List, None]:
|
||||
r"""
|
||||
h -> 4h mlp layer
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def mlp_out() -> Union[List, None]:
|
||||
r"""
|
||||
4h -> h mlp layer
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def embedding() -> Union[List, None]:
|
||||
r"""
|
||||
Partially slice the embedding layer
|
||||
|
||||
Return:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def unembedding() -> Union[List, None]:
|
||||
r"""
|
||||
Partially slice the embedding layer, None means there is no unembedding layer
|
||||
|
||||
Return:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
return None
|
||||
|
|
|
@@ -1,220 +1,77 @@
|
|||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
|
||||
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
|
||||
from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer
|
||||
from ..shard.shard_config import ShardConfig
|
||||
from ..utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class ParallelModule():
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
|
||||
def preprocess(self, shard_config: ShardConfig = None):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the embedding layer to make the vocabulary size divisible by world_size
|
||||
"""
|
||||
# TODO:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
return self.model
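A quick check of the padding arithmetic above, with hypothetical sizes (bert-base has a vocabulary of 30522 tokens):

```python
# Worked example (not part of the commit): pad the vocabulary so it divides evenly.
vocab_size, world_size = 30522, 4
new_vocab_size = vocab_size + world_size - vocab_size % world_size
assert new_vocab_size == 30524 and new_vocab_size % world_size == 0
```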
|
||||
|
||||
def module_policy(self, shard_config: ShardConfig = None):
|
||||
return {
|
||||
BertLayer:
|
||||
Argument(
|
||||
attr_dict={
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
# 1. shard hidden size
|
||||
"attention.self.all_head_size": config.hidden_size // world_size,
|
||||
"crossattention.self.all_head_size": config.hidden_size // world_size,
|
||||
"attention.self.all_head_size":
|
||||
self.model.config.hidden_size // shard_config.tensor_parallel_size,
|
||||
"crossattention.self.all_head_size":
|
||||
self.model.config.hidden_size // shard_config.tensor_parallel_size,
|
||||
# 2. shard number of heads
|
||||
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
"attention.self.num_attention_heads":
|
||||
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
|
||||
"crossattention.self.num_attention_heads":
|
||||
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
|
||||
BertEmbeddings:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. shard vocab size
|
||||
"word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
|
||||
},
|
||||
param_funcs=[
|
||||
BertPolicy.embedding,
|
||||
]),
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.query",
|
||||
target_module=ParallelModule,
|
||||
),
|
||||
])
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def attn_in():
|
||||
return [
|
||||
Col_Layer(
|
||||
suffix="attention.self.query",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="attention.self.key",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="attention.self.value",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Dropout_Layer(
|
||||
suffix="attention.self.dropout",
|
||||
p="p",
|
||||
replace_layer=col_nn.Dropout1D,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="crossattention.self.query",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
ignore=True,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="crossattention.self.key",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
ignore=True,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="crossattention.self.value",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
ignore=True,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def attn_out():
|
||||
return [
|
||||
Row_Layer(
|
||||
suffix="attention.output.dense",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
),
|
||||
Dropout_Layer(
|
||||
suffix="attention.output.dropout",
|
||||
p="p",
|
||||
replace_layer=col_nn.Dropout1D,
|
||||
),
|
||||
Row_Layer(
|
||||
suffix="crossattention.output.dense",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
ignore=True,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_in():
|
||||
return [
|
||||
Col_Layer(
|
||||
suffix="intermediate.dense",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_out():
|
||||
return [
|
||||
Row_Layer(
|
||||
suffix="output.dense",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
),
|
||||
Dropout_Layer(
|
||||
suffix="output.dropout",
|
||||
p="p",
|
||||
replace_layer=col_nn.Dropout1D,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def embedding():
|
||||
return [Col_Layer(
|
||||
suffix="word_embeddings",
|
||||
weight="weight",
|
||||
replace_layer=col_nn.VocabParallelEmbedding1D,
|
||||
)]
|
||||
|
||||
@staticmethod
|
||||
def unembedding():
|
||||
return [
|
||||
Col_Layer(
|
||||
suffix="decoder",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
gather_output=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# BertModel
|
||||
class BertModelPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
return BertPolicy.argument_policy(config, world_size)
|
||||
|
||||
|
||||
# BertForPretraining
|
||||
class BertForPretrainingPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
base_argument = BertPolicy.argument_policy(config, world_size)
|
||||
argument = {
|
||||
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
|
||||
BertPolicy.unembedding,
|
||||
]),
|
||||
}
|
||||
argument.update(base_argument)
|
||||
return argument
|
||||
|
||||
@staticmethod
|
||||
def inject_policy():
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def binding_policy():
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
|
||||
|
||||
# BertForMaskedLM
|
||||
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
|
||||
def postprocess(self):
|
||||
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
param = nn.Parameter(param)
|
||||
setattr_(self.model, k, param)
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
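The `getattr_`/`setattr_` helpers imported from `..utils` are dotted-path accessors; a minimal sketch of the assumed behaviour (names here are illustrative, not the real implementation):

```python
# Illustrative stand-ins for the dotted-path helpers used in postprocess() above.
def getattr_sketch(obj, path: str):
    # e.g. getattr_sketch(model, "bert.embeddings.word_embeddings.weight")
    for name in path.split("."):
        obj = getattr(obj, name)
    return obj

def setattr_sketch(obj, path: str, value):
    *parents, last = path.split(".")
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, last, value)
```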
|
||||
|
||||
|
||||
class BertForMaskedLMPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
base_argument = BertPolicy.argument_policy(config, world_size)
|
||||
argument = {
|
||||
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
|
||||
BertPolicy.unembedding,
|
||||
]),
|
||||
}
|
||||
argument.update(base_argument)
|
||||
return argument
|
||||
|
||||
@staticmethod
|
||||
def inject_policy():
|
||||
# return (BertForMaskedLM, BertForMaskedLM_)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def binding_policy():
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
# BertLMHeadModel
|
||||
|
@@ -231,36 +88,5 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
argument.update(base_argument)
|
||||
return argument
|
||||
|
||||
@staticmethod
|
||||
def inject_policy():
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def binding_policy():
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
|
||||
|
||||
# BertForNextSentencePrediction
|
||||
class BertForNextSentencePredictionPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
return BertPolicy.argument_policy(config, world_size)
|
||||
|
||||
|
||||
# BertForSequenceClassification
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
return BertPolicy.argument_policy(config, world_size)
|
||||
|
||||
|
||||
# BertForMultipleChoice
|
||||
class BertForMultipleChoicePolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
return BertPolicy.argument_policy(config, world_size)
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
|
@@ -1,5 +1,5 @@
|
|||
from .shard_config import ShardConfig
|
||||
from .sharder import ModelSharder, shard_model
|
||||
from .slicer import Slicer
|
||||
from .sharder import ModelSharder
|
||||
from .shardformer import ShardFormer
|
||||
|
||||
__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
|
||||
__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer']
|
||||
|
|
|
@@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Literal
|
||||
|
||||
__all__ = ['ShardConfig']
|
||||
|
||||
|
@@ -9,10 +10,18 @@ class ShardConfig:
|
|||
The config for sharding the huggingface model
|
||||
|
||||
Args:
|
||||
rank (int): The rank of local process
|
||||
world_size (int): The world size of the distributed process
|
||||
data_parallel_size (int): The size of data parallel
|
||||
tensor_parallel_size (int): The size of tensor parallel
|
||||
pipeline_parallel_size (int): The size of pipeline parallel
|
||||
tensor_parallel_mode (Literal): The mode of tensor parallel, chosen from `['1d','2d','2.5d','3d']`
|
||||
inference_only (bool): Whether to use inference-only mode; when set to `True`, the model
|
||||
will not calculate the loss and just return the output.
|
||||
gather_output (bool): Whether to gather the output of the last layer of the model
|
||||
"""
|
||||
rank: int = None
|
||||
world_size: int = None
|
||||
data_parallel_size: int
|
||||
tensor_parallel_size: int
|
||||
|
||||
pipeline_parallel_size: int
|
||||
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
||||
inference_only: bool = True
|
||||
gather_output: bool = True
|
||||
|
|
|
@@ -4,11 +4,12 @@ import torch
|
|||
import torch.nn as nn
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from colossalai.cluster.process_group_manager import ProcessGroupManager
|
||||
|
||||
from ..policies.autopolicy import get_autopolicy
|
||||
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
|
||||
from ..utils.utils import getattr_, hasattr_, setattr_
|
||||
from ..policies.basepolicy import Policy
|
||||
from ..utils.utils import setattr_
|
||||
from .shard_config import ShardConfig
|
||||
from .slicer import Slicer
|
||||
|
||||
__all__ = ['ModelSharder', 'shard_model']
|
||||
|
||||
|
@@ -28,20 +29,23 @@ class ModelSharder(object):
|
|||
model: nn.Module,
|
||||
policy: Policy,
|
||||
shard_config: ShardConfig = None, # TODO
|
||||
) -> None:
|
||||
pg_manager: ProcessGroupManager = None) -> None:
|
||||
self.model = model
|
||||
self.policy = get_autopolicy(self.model) if policy is None else policy
|
||||
self.slicer = Slicer(shard_config)
|
||||
self.shard_config = shard_config
|
||||
self.model_config = self.model.config
|
||||
self.pg_manager = pg_manager
|
||||
|
||||
def shard(self) -> None:
|
||||
self.reshape_embedding()
|
||||
self.inject_model(self.model)
|
||||
self.replace_layer(self.model)
|
||||
self.bind_layer(self.model)
|
||||
r"""
|
||||
Shard the model according to the policy
|
||||
"""
|
||||
self.policy.set_model(self.model)
|
||||
self.preprocess()
|
||||
self.replace_model_class()
|
||||
self.replace_module()
|
||||
self.postprocess()
|
||||
|
||||
def reshape_embedding(self,) -> None:
|
||||
def reshape_embedding(self) -> None:
|
||||
r"""
|
||||
Reshape the embedding layer to make the vocabulary size divisible by world_size
|
||||
"""
|
||||
|
@@ -52,10 +56,13 @@ class ModelSharder(object):
|
|||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.model_config = self.model.config
|
||||
|
||||
def inject_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
) -> None:
|
||||
def preprocess(self) -> None:
|
||||
self.model = self.policy.preprocess(self.shard_config)
|
||||
|
||||
def postprocess(self) -> None:
|
||||
self.model = self.policy.postprocess()
|
||||
|
||||
def replace_model_class(self,) -> None:
|
||||
r"""
|
||||
Replace the model class with the policy-defined model class
|
||||
Mainly modifies the forward and backward passes to fit the distributed model
|
||||
|
@@ -64,49 +71,43 @@ class ModelSharder(object):
|
|||
::
|
||||
BertForMaskedLM.forward -> BertForMaskedLM_.forward
|
||||
"""
|
||||
inject_policy = self.policy.inject_policy()
|
||||
if inject_policy is None:
|
||||
new_model_class = self.policy.new_model_class()
|
||||
if new_model_class is None:
|
||||
return
|
||||
|
||||
if inject_policy is None:
|
||||
return
|
||||
org_model_cls = inject_policy[0]
|
||||
shard_model_cls = inject_policy[1]
|
||||
for key in new_model_class.__dict__.keys():
|
||||
if hasattr(self.model.__class__, key):
|
||||
setattr(
|
||||
self.model.__class__,
|
||||
key,
|
||||
getattr(new_model_class, key),
|
||||
)
|
||||
|
||||
if model.__class__ == org_model_cls:
|
||||
for key in shard_model_cls.__dict__.keys():
|
||||
if hasattr(model.__class__, key):
|
||||
setattr(
|
||||
model.__class__,
|
||||
key,
|
||||
getattr(shard_model_cls, key),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{model.__class__} is not implemented so far")
|
||||
|
||||
def replace_layer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
) -> None:
|
||||
def replace_module(self,) -> None:
|
||||
r"""
|
||||
Replace the layer according to the policy, and replace the layer one by one
|
||||
Replace the module according to the policy, and replace the module one by one
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): The layer to shard
|
||||
model (:class:`torch.nn.Module`): The model to shard
|
||||
"""
|
||||
argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
|
||||
for argument_policy in argument_policies.items():
|
||||
origin_layer_cls = argument_policy[0]
|
||||
attr_dict = argument_policy[1].attr_dict
|
||||
param_funcs = argument_policy[1].param_funcs
|
||||
self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
|
||||
print(self.policy)
|
||||
module_descriptions = self.policy.module_policy(self.shard_config)
|
||||
print(f"*******{module_descriptions}")
|
||||
for module_description in module_descriptions.items():
|
||||
origin_layer_cls = module_description[0]
|
||||
attr_replacement = module_description[1].attribute_replacement
|
||||
param_replacement = module_description[1].param_replacement
|
||||
sub_module_replacement = module_description[1].sub_module_replacement
|
||||
self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
|
||||
sub_module_replacement)
|
||||
|
||||
def traverse_replace_layer(
|
||||
def _recursive_replace_layer(
|
||||
self,
|
||||
layer: nn.Module,
|
||||
module: nn.Module,
|
||||
origin_cls: nn.Module,
|
||||
attr_dict: Dict[str, Any],
|
||||
param_funcs: List[Callable],
|
||||
attr_replacement: Dict[str, Any],
|
||||
param_replacement: List[Callable],
|
||||
sub_module_replacement: List[Callable],
|
||||
) -> None:
|
||||
r"""
|
||||
Recursively traverse the model and replace the target layers
|
||||
|
@@ -114,21 +115,52 @@ class ModelSharder(object):
|
|||
Args:
|
||||
layer (:class:`torch.nn.Module`): The object of layer to shard
|
||||
origin_cls (:class:`transformers.model`): The origin layer class
|
||||
attr_dict (Dict): The attribute dict to modify
|
||||
policy_cls (:class:`Policy`): The policy class
|
||||
attr_replacement (Dict): The attribute dict to modify
|
||||
param_replacement (List[Callable]): The function list to get parameter shard information in policy
|
||||
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
|
||||
"""
|
||||
if layer.__class__ == origin_cls:
|
||||
for k, v in attr_dict.items():
|
||||
setattr_(layer, k, v, ignore=True)
|
||||
self.shard_one_layer(layer, param_funcs)
|
||||
for name, child in layer.named_children():
|
||||
self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
|
||||
return layer
|
||||
if module.__class__ == origin_cls:
|
||||
self._replace_attr(module, attr_replacement)
|
||||
self._replace_param(module, param_replacement)
|
||||
self._replace_sub_module(module, sub_module_replacement)
|
||||
for name, child in module.named_children():
|
||||
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
|
||||
sub_module_replacement)
|
||||
|
||||
def shard_one_layer(
|
||||
def _replace_attr(
|
||||
self,
|
||||
module: nn.Module,
|
||||
attr_replacement: Dict[str, Any],
|
||||
) -> None:
|
||||
r"""
|
||||
Replace the attribute of the layer
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The object of layer to shard
|
||||
attr_replacement (Dict): The attribute dict to modify
|
||||
"""
|
||||
for k, v in attr_replacement.items():
|
||||
setattr_(module, k, v, ignore=True)
|
||||
|
||||
def _replace_param(
|
||||
self,
|
||||
module: nn.Module,
|
||||
param_replacement: List[Callable],
|
||||
) -> None:
|
||||
r"""
|
||||
Replace the parameter of the layer
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The object of layer to shard
|
||||
param_replacement (List[Callable]): The function list to get parameter shard information in policy
|
||||
"""
|
||||
# TODO: support parameter shard
|
||||
pass
|
||||
|
||||
def _replace_sub_module(
|
||||
self,
|
||||
org_layer: nn.Module,
|
||||
param_funcs: List[Callable],
|
||||
sub_module_replacement: List[Callable],
|
||||
) -> None:
|
||||
r"""
|
||||
Shard one layer according to the policy; the layer should be of the same class as the key in the dict returned by the policy's argument_policy
|
||||
|
@@ -138,145 +170,14 @@ class ModelSharder(object):
|
|||
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
|
||||
|
||||
"""
|
||||
for func in param_funcs:
|
||||
policy_layers = func()
|
||||
for policy_layer in policy_layers:
|
||||
suffix = policy_layer.suffix
|
||||
replace_layer_cls = policy_layer.replace_layer
|
||||
ignore = policy_layer.ignore
|
||||
reversed = policy_layer.reversed
|
||||
n_cast = policy_layer.n_cast
|
||||
for description in sub_module_replacement:
|
||||
suffix = description.suffix
|
||||
target_module = description.target_module
|
||||
kwargs = description.kwargs
|
||||
|
||||
assert replace_layer_cls is not None, 'replace_layer should not be None'
|
||||
assert target_module is not None, 'target_module should not be None'
|
||||
|
||||
# create new object to replace the origin layer
|
||||
# Linear
|
||||
suffix_layer = getattr_(org_layer, suffix, ignore=True)
|
||||
assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
|
||||
if suffix_layer is None and ignore:
|
||||
continue
|
||||
if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
|
||||
weight = None
|
||||
bias = None
|
||||
weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
|
||||
bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None
|
||||
|
||||
if weight_attr is not None:
|
||||
if hasattr_(org_layer, weight_attr):
|
||||
weight = getattr_(org_layer, weight_attr)
|
||||
else:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
|
||||
|
||||
if bias_attr is not None:
|
||||
if hasattr_(org_layer, bias_attr):
|
||||
bias = getattr_(org_layer, bias_attr)
|
||||
else:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
|
||||
|
||||
# set the sliced weight and bias to the new nn_col layer
|
||||
assert weight is not None or bias is not None
|
||||
|
||||
# slice weight and bias
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
|
||||
|
||||
if replace_layer_cls.__name__ == "Linear1D_Row":
|
||||
replace_layer = replace_layer_cls(weight.shape[1],
|
||||
weight.shape[0],
|
||||
bias=False if bias is None else True)
|
||||
elif replace_layer_cls.__name__ == "Linear1D_Col":
|
||||
gather_output = policy_layer.gather_output and self.shard_config.gather_output
|
||||
replace_layer = replace_layer_cls(weight.shape[0],
|
||||
weight.shape[1],
|
||||
bias=False if bias is None else True,
|
||||
gather_output=gather_output)
|
||||
elif replace_layer_cls.__name__ == "Embedding1D":
|
||||
gather_output = policy_layer.gather_output
|
||||
replace_layer = replace_layer_cls(weight.shape[0],
|
||||
weight.shape[1],
|
||||
gather_output=gather_output)
|
||||
elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
|
||||
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
|
||||
getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
|
||||
# setattr_(org_layer, suffix, replace_layer, ignore=ignore)
|
||||
# self.set_param(replace_layer, weight, bias)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Replacing to {replace_layer_cls.__name__} is not implemented so far")
|
||||
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
|
||||
self.set_param(replace_layer, weight, bias)
|
||||
# dropout
|
||||
elif isinstance(policy_layer, Dropout_Layer):
|
||||
p_attr = suffix + '.' + policy_layer.p
|
||||
p = getattr_(org_layer, p_attr, ignore=True)
|
||||
replace_layer = replace_layer_cls(p)
|
||||
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far")
|
||||
|
||||
def set_param(self,
|
||||
layer: Any,
|
||||
weight: torch.Tensor = None,
|
||||
bias: torch.Tensor = None,
|
||||
layer_attr: str = "") -> None:
|
||||
r"""
|
||||
Reset the weight and bias of the layer object
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The layer object
|
||||
layer_attr (str): The attribute name of the layer
|
||||
weight (:class:`torch.Tensor`): The weight of the layer
|
||||
bias (:class:`torch.Tensor`): The bias of the layer
|
||||
"""
|
||||
assert weight is not None or bias is not None
|
||||
if weight is not None:
|
||||
setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
|
||||
self.set_layer_size(layer, layer_attr, weight.shape)
|
||||
if bias is not None:
|
||||
setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))
|
||||
|
||||
def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
|
||||
r"""
|
||||
Set the layer attribute
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The layer object
|
||||
layer_attr (str): The attribute name of the layer
|
||||
size (:class:`torch.Size`): The size of the tensor
|
||||
"""
|
||||
# Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
|
||||
attrs = ["out_features", "in_features"]
|
||||
for i, attr in enumerate(attrs):
|
||||
if hasattr_(layer, f"{layer_attr}.{attr}"):
|
||||
setattr_(layer, f"{layer_attr}.{attr}", size[i])
|
||||
|
||||
def bind_layer(self, model: nn.Module) -> None:
|
||||
r"""
|
||||
Bind the layer according to the binding policy
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): The shard model
|
||||
"""
|
||||
binding_map = self.policy.binding_policy()
|
||||
if binding_map is None:
|
||||
return
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(model, k)
|
||||
param = nn.Parameter(param)
|
||||
setattr_(model, k, param)
|
||||
setattr_(model, v, param)
|
||||
|
||||
|
||||
def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
|
||||
r"""
|
||||
The function is used to shard the PyTorch model.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Model`): the origin huggingface model
|
||||
shard_config (`ShardConfig`): the config for distribute information
|
||||
policy (`Policy`): the custom policy for sharding
|
||||
"""
|
||||
# TODO: init shard_config automatically
|
||||
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
|
||||
sharder.shard()
|
||||
return model
|
||||
# TODO: integrate with new layer
|
||||
# replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
|
||||
replace_layer = None
|
||||
setattr_(org_layer, suffix, replace_layer)
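The commented-out call above anticipates a `from_native_module`-style constructor on the target module (the earlier `SubModuleReplacementDescription` docstring names `ParallelModule.from_native_module`). A hypothetical sketch of that interface, under the assumption that it receives the original layer and a process group:

```python
import torch.nn as nn

class ParallelLinearSketch(nn.Linear):
    """Illustrative only: builds itself from a native nn.Linear."""

    @classmethod
    def from_native_module(cls, module: nn.Linear, process_group=None, **kwargs):
        # A real implementation would shard the parameters over `process_group`;
        # this sketch simply copies the original layer's shape and weights.
        new_module = cls(module.in_features, module.out_features,
                         bias=module.bias is not None)
        new_module.load_state_dict(module.state_dict())
        return new_module
```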
|
||||
|
|
|
@@ -0,0 +1,77 @@
|
|||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from colossalai.cluster import DistCoordinator, ProcessGroupManager
|
||||
|
||||
from ..policies.basepolicy import Policy
|
||||
from .shard_config import ShardConfig
|
||||
from .sharder import ModelSharder
|
||||
|
||||
|
||||
class ShardFormer:
|
||||
"""
|
||||
Parallelize model based on the given config and policy
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from colossalai.shardformer import ShardFormer, ShardConfig
|
||||
from transformers import BertForMaskedLM
|
||||
import colossalai
|
||||
import torch
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||
shard_config = ShardConfig(
|
||||
tensor_parallel_size=2,
|
||||
data_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
tensor_parallel_mode='1d',
|
||||
inference_only=True,
|
||||
gather_output=True
|
||||
)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
shard_former.init_distributed()
|
||||
model = shard_former.shard_model(org_model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, shard_config: ShardConfig):
|
||||
"""
|
||||
Do two things:
|
||||
1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
|
||||
2. Serve as a store for the shard config and the process group manager
|
||||
"""
|
||||
self.coordinator = DistCoordinator()
|
||||
self.shard_config = shard_config
|
||||
self.pg_manager = None
|
||||
|
||||
def init_distributed(self) -> ProcessGroupManager:
|
||||
"""
|
||||
Initialize the distributed process groups according to the shard config
|
||||
"""
|
||||
pg_manager = ProcessGroupManager()
|
||||
if (self.shard_config.tensor_parallel_mode == '1d'):
|
||||
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
|
||||
self.pg_manager = pg_manager
|
||||
return pg_manager
|
||||
|
||||
def shard_model(self, model: nn.Module, policy: Policy = None):
|
||||
r"""
|
||||
The function is used to shard the PyTorch model.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Model`): the origin huggingface model
|
||||
shard_config (`ShardConfig`): the config for distribute information
|
||||
policy (`Policy`): the custom policy for sharding
|
||||
"""
|
||||
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
|
||||
sharder.shard()
|
||||
return model
|
||||
|
||||
def shard_dataset(self, dataset: Dataset):
|
||||
"""
|
||||
Shard dataset for DP
|
||||
"""
|
||||
pass
|
|
@ -1,163 +0,0 @@
|
|||
import torch
|
||||
|
||||
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
|
||||
from .shard_config import ShardConfig
|
||||
|
||||
dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}
|
||||
|
||||
|
||||
class Slicer():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shardconfig: ShardConfig #TODO
|
||||
) -> None:
|
||||
self.shardconfig = shardconfig
|
||||
|
||||
def slice_weight_bias(
|
||||
self,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
policy_layer_cls: Layer,
|
||||
n_cast: int = None,
|
||||
reversed: bool = False,
|
||||
):
|
||||
r"""
|
||||
Slice the weight and bias according to policy layer cls
|
||||
``Layer`` -> do nothing
|
||||
``Col_Layer`` -> slice the weight and bias along dim 0
|
||||
``Row_Layer`` -> slice the weight along dim 1 and do not slice the bias
|
||||
|
||||
Args:
|
||||
weight (:class:`torch.nn.Module`): The weight of the layer
|
||||
bias: (:class:`torch.nn.Module`): The bias of the layer
|
||||
policy_layer_class (:class:`Policy`): The class represent how to slice the tensor
|
||||
"""
|
||||
if policy_layer_cls in [Layer, Dropout_Layer]:
|
||||
return weight, bias
|
||||
|
||||
dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
|
||||
# print(weight.shape, dim)
|
||||
if policy_layer_cls == Col_Layer:
|
||||
weight = self.slice_tensor(weight, dim, False, n_cast)
|
||||
bias = self.slice_tensor(bias, 0, True, n_cast)
|
||||
elif policy_layer_cls == Row_Layer:
|
||||
weight = self.slice_tensor(weight, dim, False, n_cast)
|
||||
elif policy_layer_cls == Embedding_Layer:
|
||||
weight = self.slice_tensor(weight, dim, False, n_cast)
|
||||
else:
|
||||
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
|
||||
if reversed:
|
||||
weight = weight.transpose(0, 1).contiguous()
|
||||
return weight, bias
|
||||
|
||||
def slice_tensor(
|
||||
self,
|
||||
tensor_in: torch.Tensor,
|
||||
dim: int,
|
||||
is_bias: bool,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice tensor according to the config
|
||||
|
||||
Args:
|
||||
tensor_in (:class:`torch.Tensor`): The tensor to slice
|
||||
dim (int): The dimension to slice
|
||||
is_bias (bool): Whether the tensor is bias
|
||||
"""
|
||||
if tensor_in is None:
|
||||
return None
|
||||
if not is_bias:
|
||||
return self.slice_2d(tensor_in, dim, n_cast)
|
||||
else:
|
||||
return self.slice_1d(tensor_in, n_cast)
|
||||
|
||||
def slice_2d(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
dim: int,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the 2D tensor
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
dim (int): The dimension to slice
|
||||
"""
|
||||
assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
|
||||
if dim == 0:
|
||||
return self.slice_row(tensor, n_cast)
|
||||
elif dim == 1:
|
||||
return self.slice_col(tensor, n_cast)
|
||||
|
||||
def slice_1d(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the 1D tensor
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=0).contiguous()
|
||||
|
||||
def slice_col(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the tensor in column
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
|
||||
"""
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=1).contiguous()
|
||||
|
||||
def slice_row(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the tensor in row
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=0).contiguous()
|
|
@@ -0,0 +1 @@
|
|||
from .utils import getattr_, hasattr_, setattr_
|