2023-05-22 07:02:17 +00:00
|
|
|
# part of code modified from https://github.com/tunib-ai/parallelformers
|
|
|
|
|
2023-06-07 08:09:40 +00:00
|
|
|
from dataclasses import dataclass
|
2023-06-12 08:52:18 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Tuple, Union
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
import torch.nn as nn
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
|
|
|
|
@dataclass
class Argument:
    r"""
    The argument class for the policy

    Args:
        attr_dict (Dict[str, Any]): The dict for the param setting
        param_funcs (List[Callable]): The list for the param functions
    """
    # Both fields are required (no defaults) and are positional in this order.
    attr_dict: Dict[str, Any]
    param_funcs: List[Callable]
|
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
|
|
|
|
@dataclass
class Layer:
    r"""
    The layer object for the policy

    Args:
        suffix (str): The suffix of the layer.
        replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
        ignore (bool): Whether to ignore this layer if it is not in the model
        reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
            but in GPT2 `Conv1D` layer is [in, out] which is reversed.
        n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3.
            Commonly in TP, we just chunk the weight with the number of devices,
            but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
            each device should have a part of Q, K and V weight.
    """
    # ``None`` is the "not set" default, so the annotation must admit it.
    suffix: Union[str, None] = None
    replace_layer: Any = None
    ignore: bool = False
    # Field name kept for backward compatibility even though it matches the builtin ``reversed``.
    reversed: bool = False
    # ``None`` is the "not set" default, so the annotation must admit it.
    n_cast: Union[int, None] = None
|
2023-05-22 07:02:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Col_Layer(Layer):
    r"""
    Class for col shard layer in tensor parallel

    Args:
        weight (str): The weight suffix of the layer
        bias (str): The bias suffix of the layer
        gather_output (bool): Whether to gather the output of the layer
    """
    # ``None`` is the "not set" default, so the annotations must admit it.
    weight: Union[str, None] = None
    bias: Union[str, None] = None
    gather_output: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Row_Layer(Layer):
    r"""
    Class for row shard layer in tensor parallel

    Args:
        weight (str): The weight suffix of the layer
        bias (str): The bias suffix of the layer
    """
    # ``None`` is the "not set" default, so the annotations must admit it.
    weight: Union[str, None] = None
    bias: Union[str, None] = None
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Dropout_Layer(Layer):
    r"""
    Class for dropout layer in tensor parallel

    Args:
        p (str): The dropout rate suffix of the layer
    """
    # ``None`` is the "not set" default, so the annotation must admit it.
    p: Union[str, None] = None
|
2023-05-22 07:02:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Policy:
    r"""
    The base class for all the policies

    For each different model, it should have a different policy class, like BertPolicy for Bert Model
    or OPTPolicy for OPT model.

    AutoPolicy:
        Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
        to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
        like BertPolicy, and for each different Bert model in huggingface like, BertForMaskedLM,
        BertForSequenceClassification, etc., for each different Bert model we define different policy class
        and overwrite the method like ``inject_policy`` to modify the forward and backward process.

    CustomPolicy:
        If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
        all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
        class for the example.

    """

    @staticmethod
    def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
        r"""
        Return the dict for the modify policy, the key is the original layer class and the value is the
        argument for the modify layer

        Args:
            model_config (:class:`transformers.Config`): The config of transformer model
            world_size (int): The world size of sharding model

        Returns:
            Dict for the modify policy,
            ::
                {
                    origin layer class1 (nn.Module): Argument(
                        attr_dict = {
                            argument1: value1,
                            argument2: value2,
                            ...
                        },
                        param_funcs = [
                            staticmethod1,
                            staticmethod2,
                            ...
                        ]
                    ),
                    origin layer class2 (nn.Module): Argument(
                        attr_dict = {
                            argument1: value1,
                            argument2: value2,
                            ...
                        },
                        param_funcs = [
                            staticmethod1,
                            staticmethod2,
                            ...
                        ]
                    ),
                    ...
                }

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
        r"""
        Return the tuple for the inject model

        Returns:
            The injected model pair, the first item is the original model and the second is the new shardmodel
            ::
                (OriginModel, CustomModel)
                in `CustomModel`, we can overwrite the forward and backward process

            The base implementation returns None.
        """
        return None

    @staticmethod
    def binding_policy() -> Union[Dict[str, str], None]:
        r"""
        Return the dict for the binding model, None means no need to bind

        Returns:
            This method should return the binding relationship for some layers share the weight or bias,
            the key and value is the suffix of the weight or bias of the model
            ::
                return {
                    "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
                }
        """
        return None

    @staticmethod
    def attn_in() -> Union[List, None]:
        r"""
        Attention qkv layer
        In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
        ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
        in ``Layer`` object can refer to the ``Layer`` class.

        Returns:
            List[Layer]: List of layer object. The base implementation returns None.
        """
        return None

    @staticmethod
    def attn_out() -> Union[List, None]:
        r"""
        Attention output projection layer

        Returns:
            List[Layer]: List of layer object. The base implementation returns None.
        """
        return None

    @staticmethod
    def mlp_in() -> Union[List, None]:
        r"""
        h -> 4h mlp layer

        Returns:
            List[Layer]: List of layer object. The base implementation returns None.
        """
        return None

    @staticmethod
    def mlp_out() -> Union[List, None]:
        r"""
        4h -> h mlp layer

        Returns:
            List[Layer]: List of layer object. The base implementation returns None.
        """
        return None

    @staticmethod
    def embedding() -> Union[List, None]:
        r"""
        Partially slice the embedding layer

        Returns:
            List[Layer]: List of layer object. The base implementation returns None.
        """
        return None

    @staticmethod
    def unembedding() -> Union[List, None]:
        r"""
        Partially slice the unembedding layer, None means there is no unembedding layer

        Returns:
            List[Layer]: List of layer object. The base implementation returns None.
        """
        return None
|