[shardformer] Add dropout layer in shard model and refactor policy api (#3949)

* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage
pull/4157/head
FoolPlayer 2023-06-12 16:52:18 +08:00 committed by Frank Lee
parent a73130482d
commit 45927d5527
7 changed files with 266 additions and 197 deletions


@ -55,7 +55,7 @@ colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py
## 💡 Policy
If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model.
If you want to parallelize the model in a custom way, just overwrite the policy class for the Hugging Face model. Please refer to any of the pre-defined policies, such as [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py).
You should do:
@ -68,7 +68,7 @@ You should do:
- Shardformer will inject the model according to this method. If you need to modify the forward or backward process (like the distributed cross-entropy loss in Bert), you need to overwrite this method.
4. Overwrite or add the param functions
- These functions use a suffix to record the path of the weight or bias for the layer.
- The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively.
- The return value is a list containing `Col_Layer`, `Row_Layer` or `Dropout_Layer` objects, which slice the parameters along the column and row dimension respectively, or replace a dropout layer; refer to the `Layer` class for more details (see the sketch after this list).
5. Overwrite `binding_policy` (Optional)
- Overwrite to specify that Shardformer will bind some weights between layers, like the embedding and unembedding layers.
- This function returns a dict whose keys and values are the suffixes of the weights that need to be bound.
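A minimal sketch of such a param function, mirroring the entries used in the [bert policy](./policies/bert.py) of this commit (the absolute import paths are assumptions based on the repository layout):

```python
import colossalai.shardformer.layer.layers as col_nn
# import paths assumed from the repository layout
from colossalai.shardformer.policies.basepolicy import Col_Layer, Dropout_Layer
from colossalai.shardformer.policies.bert import BertPolicy


class MyBertPolicy(BertPolicy):

    @staticmethod
    def attn_in():
        # slice the query projection along the column dimension and replace the
        # attention dropout with the parallel-aware Dropout1D layer
        # (a full policy would also list key/value and cross-attention entries, see bert.py)
        return [
            Col_Layer(
                suffix="attention.self.query",   # path of the layer inside the module
                weight="weight",                 # weight attribute name under the suffix
                bias="bias",                     # bias attribute name under the suffix
                replace_layer=col_nn.Linear1D_Col,
            ),
            Dropout_Layer(
                suffix="attention.self.dropout",
                p="p",                           # attribute holding the dropout rate
                replace_layer=col_nn.Dropout1D,
            ),
        ]
```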
@ -123,7 +123,7 @@ class CustomPolicy(Policy):
raise NotImplementedError
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
r"""
Return the tuple for the model injection; None means no injection is needed
@ -133,12 +133,12 @@ class CustomPolicy(Policy):
(OriginModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return ()
return None
@staticmethod
def binding_policy() -> Dict:
def binding_policy() -> Union[Dict[str, str], None]:
r"""
Return the dict for the binding model
Return the dict for the binding model; None means no binding is needed
Return:
This method should return the binding relationship for layers that share weight or bias,
@ -148,69 +148,70 @@ class CustomPolicy(Policy):
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return NotImplementedError
return None
@staticmethod
def attn_in() -> List:
"""
def attn_in() -> Union[List, None]:
r"""
Attention qkv layer
Methods of this kind should return a list of ``Layer`` objects: a plain ``Layer`` for no slicing, a ``Col_Layer``
for column slicing, or a ``Row_Layer`` for row slicing. The available parameters of a ``Layer`` object are
documented in the ``Layer`` class.
Returns:
List[Layer]: List of layer objects, each describing a new layer
"""
return NotImplementedError
return None
@staticmethod
def attn_out() -> List:
"""
def attn_out() -> Union[List, None]:
r"""
Attention output projection layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
@staticmethod
def mlp_in() -> List:
"""
def mlp_in() -> Union[List, None]:
r"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
@staticmethod
def mlp_out() -> List:
"""
def mlp_out() -> Union[List, None]:
r"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
@staticmethod
def embedding() -> List:
"""
def embedding() -> Union[List, None]:
r"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
@staticmethod
def unembedding() -> List:
"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
def unembedding() -> Union[List, None]:
r"""
Partially slice the unembedding layer; None means there is no unembedding layer
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
```
@ -232,21 +233,26 @@ class CustomPolicy(Policy):
- CLASS `Layer`:
Parameters:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- suffix (str): The suffix of the layer, indicating the attribute path of the layer.
- replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
- ignore (bool): Whether to ignore this layer if it is not in the model
- reversed (bool): Whether the weight of the layer is reversed; the weight of `torch.nn.Linear` is commonly [out, in], but the GPT2 `Conv1D` layer stores it as [in, out], which is reversed.
- n_cast (int): The number of pieces the weight will be cast into; for the fused q, k, v in an attention layer, n_cast should be 3. Commonly in TP we just chunk the weight by the number of devices, but in multi-head attention we need to chunk the weight by $ devices * n\_head $ so that each device holds a part of the Q, K and V weights.
This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
This base class is used to specify the replacement policy and the suffix of the layer for a particular layer.
CLASS `Col_Layer(Layer)`:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- gather_output (bool): Whether the output of this layer should be gathered; for example, the last layer usually is, while most intermediate layers of the model do not need to be gathered.
This class inherited from `Layer`, representing the layer will be sliced along column.
This class inherits from `Layer` and represents a layer that will be sliced along the column dimension; it also indicates the attributes of the weight and bias. Setting `bias` to `None` means the bias is ignored, regardless of whether it originally exists.
CLASS `Row_Layer(Layer)`:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
This class inherited from `Layer`, representing the layer will be sliced along row.
This class inherits from `Layer` and represents a layer that will be sliced along the row dimension. It is just like `Col_Layer`, but in tensor parallelism there is no need to gather the output of a layer sliced by row.
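For example, the [gpt2 policy](./policies/gpt2.py) in this commit describes GPT2's fused `c_attn` projection, whose `Conv1D` weight is stored as [in, out] and packs Q, K and V together, so the layer object sets both `reversed=True` and `n_cast=3`:

```python
import colossalai.shardformer.layer.layers as col_nn
# import path assumed from the repository layout
from colossalai.shardformer.policies.basepolicy import Col_Layer

# GPT2's Conv1D stores its weight as [in, out] (reversed w.r.t. nn.Linear),
# and the fused q/k/v projection is chunked into 3 casts per device
qkv_layer = Col_Layer(
    suffix="attn.c_attn",
    weight="weight",
    bias="bias",
    n_cast=3,
    reversed=True,
    replace_layer=col_nn.Linear1D_Col,
)
```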
- CLASS `Policy`:
@ -254,29 +260,37 @@ class CustomPolicy(Policy):
- `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`......
These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions.
- `Policy.argument_policy()`
In this function, the user uses a dict to define which layer classes require replacement, including the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach.
- `Policy.inject_policy()`
This function will return the injected model to replace the original model. The new model should be an nn.Module class that includes modified forward or backward functions, or anything else that needs to change.
- `Policy.binding_policy()`
This function will return the weight-sharing information of the model as a dict. The key and value are both suffixes of the shared parameters.
- CLASS `ModelSharder(model, policy)`:
This class helps shard the model; its parameters are the created transformers model and the custom policy. If the custom policy is None, shardformer will automatically use the pre-defined policy for the model. A usage sketch follows the class list below.
- `ModelSharder.inject_model()`
This function is used to inject the model to modify the forward and backward process.
- `ModelSharder.replace_layer()`
This function is used to replace the original layers with colossalai layers, so that they are parallelized and can do distributed communication.
- `ModelSharder.bind_layer()`
This function is used to help different layers share weight or bias.
- CLASS `Slicer`:
This class is used to slice tensors according to the policy.
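Putting the classes together, a hedged usage sketch (the import path and constructor keywords are assumptions; per `shard/sharder.py` in this commit the real constructor may also take a `ShardConfig`, and the call sequence below simply follows the methods listed above):

```python
from transformers import BertForMaskedLM

# import path assumed from the repository layout; the real constructor
# may also require a ShardConfig (see shard/sharder.py)
from colossalai.shardformer.shard.sharder import ModelSharder

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# policy=None lets shardformer look up the pre-defined policy for the model;
# pass a custom Policy subclass to override it
sharder = ModelSharder(model=model, policy=None)

sharder.inject_model()    # swap in the modified forward/backward, if the policy defines one
sharder.replace_layer()   # replace linear/embedding/dropout layers with parallel versions
sharder.bind_layer()      # re-bind shared weights, e.g. embedding and unembedding
```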


@ -1,7 +1,7 @@
# part of code modified from https://github.com/tunib-ai/parallelformers
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Any, Callable, Dict, List, Tuple, Union
import torch.nn as nn
@ -25,8 +25,7 @@ class Layer:
The layer object for the policy
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
suffix (str): The suffix of the layer.
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the weight of the layer is reversed; the weight of `torch.nn.Linear` is commonly [out, in],
@ -35,8 +34,7 @@ class Layer:
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
each device should have a part of Q, K and V weight.
"""
weight: str = None
bias: str = None
suffix: str = None
replace_layer: Any = None
ignore: bool = False
reversed: bool = False
@ -46,20 +44,40 @@ class Layer:
@dataclass
class Col_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
Class for the col-sharded layer in tensor parallelism
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
gather_output (bool): Whether to gather the output of the layer
"""
weight: str = None
bias: str = None
gather_output: bool = False
@dataclass
class Row_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
Class for the row-sharded layer in tensor parallelism
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
"""
pass
weight: str = None
bias: str = None
@dataclass
class Dropout_Layer(Layer):
r"""
Class for the dropout layer in tensor parallelism
Args:
p (str): The suffix of the dropout-rate attribute of the layer
"""
p: str = None
class Policy():
@ -82,14 +100,14 @@ class Policy():
"""
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
r"""
Return the dict for the modification policy; the key is the original layer class and the value is the
argument for the modified layer
Args:
model_config (:class:`transformer.Config`): The config of the transformer model
shard_config (:class:`ShardConfig`): The config for sharding model
world_size (int): The world size of the sharded model
Return:
Dict for the modify policy,
@ -126,7 +144,7 @@ class Policy():
raise NotImplementedError
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
r"""
Return the tuple for the model injection; None means no injection is needed
@ -139,9 +157,9 @@ class Policy():
return None
@staticmethod
def binding_policy() -> Dict:
def binding_policy() -> Union[Dict[str, str], None]:
r"""
Return the dict for the binding model
Return the dict for the binding model; None means no binding is needed
Return:
This method should return the binding relationship for layers that share weight or bias,
@ -154,7 +172,7 @@ class Policy():
return None
@staticmethod
def attn_in() -> List:
def attn_in() -> Union[List, None]:
r"""
Attention qkv layer
Methods of this kind should return a list of ``Layer`` objects; each ``Layer`` object should be
@ -164,50 +182,40 @@ class Policy():
Returns:
List[Layer]: List of layer objects, each describing a new layer
"""
return NotImplementedError
return None
@staticmethod
def attn_out() -> List:
def attn_out() -> Union[List, None]:
r"""
Attention output projection layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
@staticmethod
def mlp_in() -> List:
def mlp_in() -> Union[List, None]:
r"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
@staticmethod
def mlp_out() -> List:
def mlp_out() -> Union[List, None]:
r"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
@staticmethod
def embedding() -> List:
r"""
Partially slice the embedding layer
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def unembedding() -> List:
def embedding() -> Union[List, None]:
r"""
Partially slice the embedding layer
@ -215,3 +223,13 @@ class Policy():
List[Layer]: List of layer object
"""
return None
@staticmethod
def unembedding() -> Union[List, None]:
r"""
Partially slice the unembedding layer; None means there is no unembedding layer
Return:
List[Layer]: List of layer object
"""
return None


@ -5,7 +5,7 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer
class BertPolicy(Policy):
@ -28,123 +28,126 @@ class BertPolicy(Policy):
Argument(
attr_dict={
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
"word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
},
param_funcs=[
BertPolicy.embedding,
]),
BertLMPredictionHead:
Argument(
attr_dict={
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
},
param_funcs=[
BertPolicy.unembedding,
])
}
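The `dim_size` attribute above is a ceil-division of the vocabulary across ranks; a quick sanity check (bert-base's vocab_size of 30522 and a world size of 4 are assumed for illustration):

```python
# ceil-division used above for the sliced embedding size
vocab_size, world_size = 30522, 4   # bert-base vocab size, assumed for illustration
dim_size = (vocab_size + world_size - 1) // world_size
assert dim_size == 7631             # == ceil(30522 / 4)
```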
@staticmethod
def binding_policy() -> Dict:
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
@staticmethod
def attn_in() -> List:
def attn_in():
return [
Col_Layer(
weight="attention.self.query.weight",
bias="attention.self.query.bias",
suffix="attention.self.query",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="attention.self.key.weight",
bias="attention.self.key.bias",
suffix="attention.self.key",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="attention.self.value.weight",
bias="attention.self.value.bias",
suffix="attention.self.value",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Dropout_Layer(
suffix="attention.self.dropout",
p="p",
replace_layer=col_nn.Dropout1D,
),
Col_Layer(
weight="crossattention.self.query.weight",
bias="crossattention.self.query.bias",
suffix="crossattention.self.query",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
weight="crossattention.self.key.weight",
bias="crossattention.self.key.bias",
suffix="crossattention.self.key",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
weight="crossattention.self.value.weight",
bias="crossattention.self.value.bias",
suffix="crossattention.self.value",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
]
@staticmethod
def attn_out() -> List:
def attn_out():
return [
Row_Layer(
weight="attention.output.dense.weight",
bias="attention.output.dense.bias",
suffix="attention.output.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
),
Dropout_Layer(
suffix="attention.output.dropout",
p="p",
replace_layer=col_nn.Dropout1D,
),
Row_Layer(
weight="crossattention.output.dense.weight",
bias="crossattention.output.dense.bias",
suffix="crossattention.output.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
ignore=True,
),
]
@staticmethod
def mlp_in() -> List:
def mlp_in():
return [
Col_Layer(
weight="intermediate.dense.weight",
bias="intermediate.dense.bias",
suffix="intermediate.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
]
@staticmethod
def mlp_out() -> List:
def mlp_out():
return [
Row_Layer(
weight="output.dense.weight",
bias="output.dense.bias",
suffix="output.dense",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
),
]
@staticmethod
def embedding() -> List:
return [Col_Layer(
weight="word_embeddings.weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
@staticmethod
def unembedding() -> List:
return [
Col_Layer(
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
Dropout_Layer(
suffix="output.dropout",
p="p",
replace_layer=col_nn.Dropout1D,
)
]
@staticmethod
def embedding():
return [Col_Layer(
suffix="word_embeddings",
weight="weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
from transformers import BertForMaskedLM
@ -154,18 +157,36 @@ from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertForMaskedLMPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument
@staticmethod
def inject_policy():
# return (BertForMaskedLM, BertForMaskedLM_)
return None
@staticmethod
def unembedding():
return [
Col_Layer(
suffix="decoder",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)
]
class BertForSequenceClassificationPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Dict:
return {}
# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# _ = BertForMaskedLMPolicy(model)
# print(isinstance(model,list(_.inject_policy().keys())[0]))
def inject_policy():
return None


@ -40,19 +40,22 @@ class GPT2Policy(Policy):
@staticmethod
def attn_in() -> List:
return [
Col_Layer(weight="attn.c_attn.weight",
bias="attn.c_attn.bias",
Col_Layer(suffix="attn.c_attn",
weight="weight",
bias="bias",
n_cast=3,
reversed=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.c_attn.weight",
bias="crossattention.c_attn.bias",
Col_Layer(suffix="crossattention.c_attn",
weight="weight",
bias="bias",
n_cast=2,
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.q_attn.weight",
bias="crossattention.q_attn.bias",
Col_Layer(suffix="crossattention.q_attn",
weight="weight",
bias="bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col)
@ -61,12 +64,14 @@ class GPT2Policy(Policy):
@staticmethod
def attn_out() -> List:
return [
Row_Layer(weight="attn.c_proj.weight",
bias="attn.c_proj.bias",
Row_Layer(suffix="attn.c_proj",
weight="weight",
bias="bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row),
Row_Layer(weight="crossattention.c_proj.weight",
bias="crossattention.c_proj.bias",
Row_Layer(suffix="crossattention.c_proj",
weight="weight",
bias="bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Row)
@ -75,21 +80,23 @@ class GPT2Policy(Policy):
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True,
replace_layer=col_nn.Linear1D_Col),
]
@staticmethod
def mlp_out() -> List:
return [
Row_Layer(weight="mlp.c_proj.weight",
bias="mlp.c_proj.bias",
Row_Layer(suffix="mlp.c_proj",
weight="weight",
bias="bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def embedding() -> List:
return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
from transformers import GPT2LMHeadModel
@ -111,8 +118,9 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
@staticmethod
def unembedding() -> List:
return [
Col_Layer(weight="lm_head.weight",
bias="lm_head.bias",
Col_Layer(suffix="lm_head",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True)
]


@ -5,7 +5,7 @@ import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
from ..utils.utils import getattr_, hasattr_, setattr_
from .shard_config import ShardConfig
from .slicer import Slicer
@ -141,65 +141,73 @@ class ModelSharder(object):
for func in param_funcs:
policy_layers = func()
for policy_layer in policy_layers:
weight = None
bias = None
weight_attr = policy_layer.weight
bias_attr = policy_layer.bias
suffix = policy_layer.suffix
replace_layer_cls = policy_layer.replace_layer
ignore = policy_layer.ignore
n_cast = policy_layer.n_cast
reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output and self.shard_config.gather_output
n_cast = policy_layer.n_cast
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
weight = getattr_(org_layer, weight_attr)
elif not ignore:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
if bias_attr is not None:
if hasattr_(org_layer, bias_attr):
bias = getattr_(org_layer, bias_attr)
elif not ignore:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
# dont have the attribute in policy, and ignore is true
if weight is None and bias is None and ignore:
continue
# set the sliced weight and bias to the new nn_col layer
assert weight is not None or bias is not None
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
assert replace_layer_cls is not None, 'replace_layer should not be None'
# create new object to replace the origin layer
if replace_layer_cls is not None:
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1],
weight.shape[0],
bias=False if bias is None else True)
elif replace_layer_cls.__name__ == "Linear1D_Col":
replace_layer = replace_layer_cls(weight.shape[0],
weight.shape[1],
bias=False if bias is None else True,
gather_output=gather_output)
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
# Linear
suffix_layer = getattr_(org_layer, suffix, ignore=True)
assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
if suffix_layer is None and ignore:
continue
if isinstance(policy_layer, (Col_Layer, Row_Layer)):
weight = None
bias = None
weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
weight = getattr_(org_layer, weight_attr)
else:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
if bias_attr is not None:
if hasattr_(org_layer, bias_attr):
bias = getattr_(org_layer, bias_attr)
else:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
# set the sliced weight and bias to the new nn_col layer
assert weight is not None or bias is not None
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1],
weight.shape[0],
bias=False if bias is None else True)
elif replace_layer_cls.__name__ == "Linear1D_Col":
gather_output = policy_layer.gather_output and self.shard_config.gather_output
replace_layer = replace_layer_cls(weight.shape[0],
weight.shape[1],
bias=False if bias is None else True,
gather_output=gather_output)
elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
# setattr_(org_layer, suffix, replace_layer, ignore=ignore)
# self.set_param(replace_layer, weight, bias)
else:
raise NotImplementedError(
f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
# do not replace the layer object, just replace the weight and bias
f"Replacing to {replace_layer_cls.__name__} is not implemented so far")
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
# dropout
elif isinstance(policy_layer, Dropout_Layer):
p_attr = suffix + '.' + policy_layer.p
p = getattr_(org_layer, p_attr, ignore=True)
replace_layer = replace_layer_cls(p)
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
else:
self.set_param(org_layer, layer_attr, weight, bias)
raise NotImplementedError(
f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far")
def set_param(self,
layer: Any,


@ -1,6 +1,6 @@
import torch
from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer
from .shard_config import ShardConfig
dim_mapping = {Col_Layer: 0, Row_Layer: 1}
@ -33,7 +33,7 @@ class Slicer():
bias: (:class:`torch.nn.Module`): The bias of the layer
policy_layer_class (:class:`Policy`): The class represent how to slice the tensor
"""
if policy_layer_cls == Layer:
if policy_layer_cls in [Layer, Dropout_Layer]:
return weight, bias
dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
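For intuition, a standalone sketch (not the library code) of what `reversed` changes in the slicing above: a column-sharded `nn.Linear` weight of shape [out, in] is chunked along dim 0, while a GPT2 `Conv1D` weight of shape [in, out] with `reversed=True` is chunked along dim 1:

```python
import torch

world_size = 2
linear_weight = torch.randn(12, 8)   # nn.Linear: [out, in] -> Col_Layer slices dim 0
conv1d_weight = torch.randn(8, 12)   # GPT2 Conv1D: [in, out] -> reversed, slice dim 1

linear_shard = torch.chunk(linear_weight, world_size, dim=0)[0]
conv1d_shard = torch.chunk(conv1d_weight, world_size, dim=1)[0]

assert linear_shard.shape == (6, 8)
assert conv1d_shard.shape == (8, 6)
```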


@ -37,7 +37,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
setattr(obj, attrs[-1], value)
def getattr_(obj, attr: str, ignore: bool = None):
def getattr_(obj, attr: str, ignore: bool = False):
r"""
Get the object's multi sublevel attr
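For reference, a minimal standalone sketch of how such a dotted-path getter behaves (a re-implementation for illustration only, not the library code):

```python
import torch.nn as nn


def getattr_demo(obj, attr: str, ignore: bool = False):
    # walk an "a.b.c"-style path; return None instead of raising when ignore=True
    for name in attr.split("."):
        if not hasattr(obj, name):
            if ignore:
                return None
            raise AttributeError(f"{obj.__class__.__name__} has no attribute {attr}")
        obj = getattr(obj, name)
    return obj


model = nn.Sequential()
model.add_module("dense", nn.Linear(4, 4))
print(getattr_demo(model, "dense.weight").shape)          # torch.Size([4, 4])
print(getattr_demo(model, "dense.missing", ignore=True))  # None
```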