From 45927d552746c1e02c9a860ca5375a5eb3facda4 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 12 Jun 2023 16:52:18 +0800 Subject: [PATCH] [shardformer] Add dropout layer in shard model and refactor policy api (#3949) * add dist dropout in model * update docstring and bert policy with dropout * refactor basepolicy and sharded, update bert * update format * update gpt2 policy * update bert policy * remove unused code * update readme for new policy usage --- colossalai/shardformer/README.md | 80 ++++++---- colossalai/shardformer/policies/basepolicy.py | 82 ++++++---- colossalai/shardformer/policies/bert.py | 147 ++++++++++-------- colossalai/shardformer/policies/gpt2.py | 40 +++-- colossalai/shardformer/shard/sharder.py | 108 +++++++------ colossalai/shardformer/shard/slicer.py | 4 +- colossalai/shardformer/utils/utils.py | 2 +- 7 files changed, 266 insertions(+), 197 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 222626db3..b8357c203 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -55,7 +55,7 @@ colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py ## 💡 Policy -If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. +If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. Please refer to any policy that we have pre-established, like [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py). You should do: @@ -68,7 +68,7 @@ You should do: - Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method. 4. Overwrite or add the param functions - These functions use a suffix to record the path of weight or bias for the layer. - - The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively. + - The return is a list contains some `Col_Layer`, `Row_Layer` or `Dropout_Layer` objects, which means slice along col and row respectively or as dropout layer, refer to CLASS `Layer` for more details. 5. Overwrite `binding_policy` (Optional) - Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers. - This function will return a dict, the key and value are the suffix of weight need to be binded. @@ -123,7 +123,7 @@ class CustomPolicy(Policy): raise NotImplementedError @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: r""" Return the dict for the inject model @@ -133,12 +133,12 @@ class CustomPolicy(Policy): (OrignModel, CustomModel) in `CustomModel`, we can overwrite the forward and backward process """ - return () + return None @staticmethod - def binding_policy() -> Dict: + def binding_policy() -> Union[Dict[str, str], None]: r""" - Return the dict for the binding model + Return the dict for the binding model, None means no need to bind Return: This method should return the binding relationship for some layers share the weight or bias, @@ -148,69 +148,70 @@ class CustomPolicy(Policy): "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } """ - return NotImplementedError + return None @staticmethod - def attn_in() -> List: - """ + def attn_in() -> Union[List, None]: + r""" Attention qkv layer + In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be + ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters + in ``Layer`` object can refer to the ``Layer`` class. Returns: List[Layer]: List of layer object, each layer is the new """ - return NotImplementedError + return None @staticmethod - def attn_out() -> List: - """ + def attn_out() -> Union[List, None]: + r""" Attention output projection layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_in() -> List: - """ + def mlp_in() -> Union[List, None]: + r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_out() -> List: - """ + def mlp_out() -> Union[List, None]: + r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def embedding() -> List: - """ + def embedding() -> Union[List, None]: + r""" Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums Return: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def unembedding() -> List: - """ - Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums + def unembedding() -> Union[List, None]: + r""" + Partially slice the embedding layer, None means there is no unembedding layer Return: List[Layer]: List of layer object """ - return NotImplementedError + return None ``` @@ -232,21 +233,26 @@ class CustomPolicy(Policy): - CLASS `Layer`: Parameters: - - weight (str): The weight suffix of the layer - - bias (str): The bias suffix of the layer + - suffix: (str): the suffix of the layer to indicate the attribute of the layer. - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - ignore (bool): Whether to ignore this layer if it is not in the model + - reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], but in GPT2 `Conv1D` layer is [in, out] which is reversed. + - n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, but in multi-head attention, we need to chunk the weight with the number of $ devices * n\_head $, and each device should have a part of Q, K and V weight. - This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class. + This class is a base class used to specify the replacement policy and the suffix the layer for a particular layer. CLASS `Col_Layer(Layer)`: + - weight (str): The weight suffix of the layer + - bias (str): The bias suffix of the layer - gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered. - This class inherited from `Layer`, representing the layer will be sliced along column. + This class inherited from `Layer`, representing the layer will be sliced along colum and indicate the attributes of weight and bias. Setting `bias` to `None` means ignoring bias, regardless of whether or not it originally exists. CLASS `Row_Layer(Layer)`: + - weight (str): The weight suffix of the layer + - bias (str): The bias suffix of the layer - This class inherited from `Layer`, representing the layer will be sliced along row. + This class inherited from `Layer`, representing the layer will be sliced along row. Just like `Col_Layer` but in tensor parrallel, there is no need to gather the output of layer sliced by row. - CLASS `Policy`: @@ -254,29 +260,37 @@ class CustomPolicy(Policy): - `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`...... These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions. + - `Policy.argument_policy()` In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach. + - `Policy.inject_policy()` This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else. + - `Policy.binding_policy()` This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters. + - CLASS `ModelSharder(model, policy)`: This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model. + - `ModelShard.inject_model()` This function is used to inject the model to modify the forward and backward progress. + - `ModelShard.replace_layer()` This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication. + - `ModelShard.bind_layer()` This function is used to help different layers share weight or bias. + - CLASS `Slicer`: This class is used to slice tensor according to policy. diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 644d115a2..d55df59fd 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,7 +1,7 @@ # part of code modified from https://github.com/tunib-ai/parallelformers from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Any, Callable, Dict, List, Tuple, Union import torch.nn as nn @@ -25,8 +25,7 @@ class Layer: The layer object for the policy Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer + suffix: (str): the suffix of the layer. replace_layer (:class:`colosalai.nn`): The layer to replace the original layer ignore (bool): Whether to ignore this layer if it is not in the model reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], @@ -35,8 +34,7 @@ class Layer: but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and each device should have a part of Q, K and V weight. """ - weight: str = None - bias: str = None + suffix: str = None replace_layer: Any = None ignore: bool = False reversed: bool = False @@ -46,20 +44,40 @@ class Layer: @dataclass class Col_Layer(Layer): r""" - Class for col shard layer in MegatronLM + Class for col shard layer in tensor parrallel Args: + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer gather_output (bool): Whether to gather the output of the layer """ + weight: str = None + bias: str = None gather_output: bool = False @dataclass class Row_Layer(Layer): r""" - Class for col shard layer in MegatronLM + Class for col shard layer in tensor parrallel + + Args: + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer """ - pass + weight: str = None + bias: str = None + + +@dataclass +class Dropout_Layer(Layer): + r""" + Class for dropout layer in tensor parrallel + + Args: + p (str): The dropout rate suffix of the layer + """ + p: str = None class Policy(): @@ -82,14 +100,14 @@ class Policy(): """ @staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: + def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]: r""" Return the dict for the modify policy, the key is the original layer class and the value is the argument for the modify layer Args: model_config (:class:`tansformer.Config`): The config of transformer model - shard_config (:class:`ShardConfig`): The config for sharding model + world_size (int)): The world size of sharding model Return: Dict for the modify policy, @@ -126,7 +144,7 @@ class Policy(): raise NotImplementedError @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: r""" Return the dict for the inject model @@ -139,9 +157,9 @@ class Policy(): return None @staticmethod - def binding_policy() -> Dict: + def binding_policy() -> Union[Dict[str, str], None]: r""" - Return the dict for the binding model + Return the dict for the binding model, None means no need to bind Return: This method should return the binding relationship for some layers share the weight or bias, @@ -154,7 +172,7 @@ class Policy(): return None @staticmethod - def attn_in() -> List: + def attn_in() -> Union[List, None]: r""" Attention qkv layer In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be @@ -164,50 +182,40 @@ class Policy(): Returns: List[Layer]: List of layer object, each layer is the new """ - return NotImplementedError + return None @staticmethod - def attn_out() -> List: + def attn_out() -> Union[List, None]: r""" Attention output projection layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_in() -> List: + def mlp_in() -> Union[List, None]: r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_out() -> List: + def mlp_out() -> Union[List, None]: r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def embedding() -> List: - r""" - Partially slice the embedding layer - - Return: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def unembedding() -> List: + def embedding() -> Union[List, None]: r""" Partially slice the embedding layer @@ -215,3 +223,13 @@ class Policy(): List[Layer]: List of layer object """ return None + + @staticmethod + def unembedding() -> Union[List, None]: + r""" + Partially slice the embedding layer, None means there is no unembedding layer + + Return: + List[Layer]: List of layer object + """ + return None diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 5d489f419..67e910d52 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -5,7 +5,7 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be import colossalai.shardformer.layer.layers as col_nn -from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer +from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer class BertPolicy(Policy): @@ -28,123 +28,126 @@ class BertPolicy(Policy): Argument( attr_dict={ # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, }, param_funcs=[ BertPolicy.embedding, ]), - BertLMPredictionHead: - Argument( - attr_dict={ - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - }, - param_funcs=[ - BertPolicy.unembedding, - ]) } @staticmethod - def binding_policy() -> Dict: + def binding_policy(): return { "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } @staticmethod - def attn_in() -> List: + def attn_in(): return [ Col_Layer( - weight="attention.self.query.weight", - bias="attention.self.query.bias", + suffix="attention.self.query", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), Col_Layer( - weight="attention.self.key.weight", - bias="attention.self.key.bias", + suffix="attention.self.key", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), Col_Layer( - weight="attention.self.value.weight", - bias="attention.self.value.bias", + suffix="attention.self.value", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), + Dropout_Layer( + suffix="attention.self.dropout", + p="p", + replace_layer=col_nn.Dropout1D, + ), Col_Layer( - weight="crossattention.self.query.weight", - bias="crossattention.self.query.bias", + suffix="crossattention.self.query", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), Col_Layer( - weight="crossattention.self.key.weight", - bias="crossattention.self.key.bias", + suffix="crossattention.self.key", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), Col_Layer( - weight="crossattention.self.value.weight", - bias="crossattention.self.value.bias", + suffix="crossattention.self.value", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), ] @staticmethod - def attn_out() -> List: + def attn_out(): return [ Row_Layer( - weight="attention.output.dense.weight", - bias="attention.output.dense.bias", + suffix="attention.output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ), + Dropout_Layer( + suffix="attention.output.dropout", + p="p", + replace_layer=col_nn.Dropout1D, + ), Row_Layer( - weight="crossattention.output.dense.weight", - bias="crossattention.output.dense.bias", + suffix="crossattention.output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ignore=True, ), ] @staticmethod - def mlp_in() -> List: + def mlp_in(): return [ Col_Layer( - weight="intermediate.dense.weight", - bias="intermediate.dense.bias", + suffix="intermediate.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), ] @staticmethod - def mlp_out() -> List: + def mlp_out(): return [ Row_Layer( - weight="output.dense.weight", - bias="output.dense.bias", + suffix="output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ), - ] - - @staticmethod - def embedding() -> List: - return [Col_Layer( - weight="word_embeddings.weight", - replace_layer=col_nn.VocabParallelEmbedding1D, - )] - - @staticmethod - def unembedding() -> List: - return [ - Col_Layer( - weight="decoder.weight", - bias="decoder.bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, + Dropout_Layer( + suffix="output.dropout", + p="p", + replace_layer=col_nn.Dropout1D, ) ] + @staticmethod + def embedding(): + return [Col_Layer( + suffix="word_embeddings", + weight="weight", + replace_layer=col_nn.VocabParallelEmbedding1D, + )] + from transformers import BertForMaskedLM @@ -154,18 +157,36 @@ from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ class BertForMaskedLMPolicy(BertPolicy): @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def argument_policy(config, world_size): + base_argument = BertPolicy.argument_policy(config, world_size) + argument = { + BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ + BertForMaskedLMPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument + + @staticmethod + def inject_policy(): # return (BertForMaskedLM, BertForMaskedLM_) return None + @staticmethod + def unembedding(): + return [ + Col_Layer( + suffix="decoder", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ) + ] + class BertForSequenceClassificationPolicy(BertPolicy): @staticmethod - def inject_policy() -> Dict: - return {} - - -# model = BertForMaskedLM.from_pretrained("bert-base-uncased") -# _ = BertForMaskedLMPolicy(model) -# print(isinstance(model,list(_.inject_policy().keys())[0])) + def inject_policy(): + return None diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 44dc9c72f..0d4342e75 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -40,19 +40,22 @@ class GPT2Policy(Policy): @staticmethod def attn_in() -> List: return [ - Col_Layer(weight="attn.c_attn.weight", - bias="attn.c_attn.bias", + Col_Layer(suffix="attn.c_attn", + weight="weight", + bias="bias", n_cast=3, reversed=True, replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.c_attn.weight", - bias="crossattention.c_attn.bias", + Col_Layer(suffix="crossattention.c_attn", + weight="weight", + bias="bias", n_cast=2, reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.q_attn.weight", - bias="crossattention.q_attn.bias", + Col_Layer(suffix="crossattention.q_attn", + weight="weight", + bias="bias", reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Col) @@ -61,12 +64,14 @@ class GPT2Policy(Policy): @staticmethod def attn_out() -> List: return [ - Row_Layer(weight="attn.c_proj.weight", - bias="attn.c_proj.bias", + Row_Layer(suffix="attn.c_proj", + weight="weight", + bias="bias", reversed=True, replace_layer=col_nn.Linear1D_Row), - Row_Layer(weight="crossattention.c_proj.weight", - bias="crossattention.c_proj.bias", + Row_Layer(suffix="crossattention.c_proj", + weight="weight", + bias="bias", reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Row) @@ -75,21 +80,23 @@ class GPT2Policy(Policy): @staticmethod def mlp_in() -> List: return [ - Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col), + Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True, + replace_layer=col_nn.Linear1D_Col), ] @staticmethod def mlp_out() -> List: return [ - Row_Layer(weight="mlp.c_proj.weight", - bias="mlp.c_proj.bias", + Row_Layer(suffix="mlp.c_proj", + weight="weight", + bias="bias", reversed=True, replace_layer=col_nn.Linear1D_Row) ] @staticmethod def embedding() -> List: - return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)] + return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)] from transformers import GPT2LMHeadModel @@ -111,8 +118,9 @@ class GPT2LMHeadModelPolicy(GPT2Policy): @staticmethod def unembedding() -> List: return [ - Col_Layer(weight="lm_head.weight", - bias="lm_head.bias", + Col_Layer(suffix="lm_head", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, gather_output=True) ] diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 159bebccd..95184cfe6 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -5,7 +5,7 @@ import torch.nn as nn from transformers.pytorch_utils import Conv1D from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Policy +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer from ..utils.utils import getattr_, hasattr_, setattr_ from .shard_config import ShardConfig from .slicer import Slicer @@ -141,65 +141,73 @@ class ModelSharder(object): for func in param_funcs: policy_layers = func() for policy_layer in policy_layers: - weight = None - bias = None - weight_attr = policy_layer.weight - bias_attr = policy_layer.bias + suffix = policy_layer.suffix replace_layer_cls = policy_layer.replace_layer ignore = policy_layer.ignore - n_cast = policy_layer.n_cast reversed = policy_layer.reversed - if policy_layer.__class__.__name__ == "Col_Layer": - gather_output = policy_layer.gather_output and self.shard_config.gather_output + n_cast = policy_layer.n_cast - if weight_attr is not None: - if hasattr_(org_layer, weight_attr): - weight = getattr_(org_layer, weight_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") - - if bias_attr is not None: - if hasattr_(org_layer, bias_attr): - bias = getattr_(org_layer, bias_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") - - # dont have the attribute in policy, and ignore is true - if weight is None and bias is None and ignore: - continue - - # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None - layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) - - # slice weight and bias - weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) + assert replace_layer_cls is not None, 'replace_layer should not be None' # create new object to replace the origin layer - if replace_layer_cls is not None: - if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)): - if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], - weight.shape[0], - bias=False if bias is None else True) - elif replace_layer_cls.__name__ == "Linear1D_Col": - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - bias=False if bias is None else True, - gather_output=gather_output) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + # Linear + suffix_layer = getattr_(org_layer, suffix, ignore=True) + assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}" + if suffix_layer is None and ignore: + continue + if isinstance(policy_layer, (Col_Layer, Row_Layer)): + weight = None + bias = None + weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None + bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None + + if weight_attr is not None: + if hasattr_(org_layer, weight_attr): + weight = getattr_(org_layer, weight_attr) + else: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") + + if bias_attr is not None: + if hasattr_(org_layer, bias_attr): + bias = getattr_(org_layer, bias_attr) + else: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") + + # set the sliced weight and bias to the new nn_col layer + assert weight is not None or bias is not None + + # slice weight and bias + weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) + + if replace_layer_cls.__name__ == "Linear1D_Row": + replace_layer = replace_layer_cls(weight.shape[1], + weight.shape[0], + bias=False if bias is None else True) + elif replace_layer_cls.__name__ == "Linear1D_Col": + gather_output = policy_layer.gather_output and self.shard_config.gather_output + replace_layer = replace_layer_cls(weight.shape[0], + weight.shape[1], + bias=False if bias is None else True, + gather_output=gather_output) + elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D": replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], - getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) + getattr_(org_layer, f"{suffix}.padding_idx", ignore=True)) + # setattr_(org_layer, suffix, replace_layer, ignore=ignore) + # self.set_param(replace_layer, weight, bias) else: raise NotImplementedError( - f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") - # do not replace the layer object, just replace the weight and bias + f"Replacing to {replace_layer_cls.__name__} is not implemented so far") + setattr_(org_layer, suffix, replace_layer, ignore=ignore) + self.set_param(replace_layer, weight, bias) + # dropout + elif isinstance(policy_layer, Dropout_Layer): + p_attr = suffix + '.' + policy_layer.p + p = getattr_(org_layer, p_attr, ignore=True) + replace_layer = replace_layer_cls(p) + setattr_(org_layer, suffix, replace_layer, ignore=ignore) else: - self.set_param(org_layer, layer_attr, weight, bias) + raise NotImplementedError( + f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far") def set_param(self, layer: Any, diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 09e3219f8..0bf8f58b8 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,6 +1,6 @@ import torch -from ..policies.basepolicy import Col_Layer, Layer, Row_Layer +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer from .shard_config import ShardConfig dim_mapping = {Col_Layer: 0, Row_Layer: 1} @@ -33,7 +33,7 @@ class Slicer(): bias: (:class:`torch.nn.Module`): The bias of the layer policy_layer_class (:class:`Policy`): The class represent how to slice the tensor """ - if policy_layer_cls == Layer: + if policy_layer_cls in [Layer, Dropout_Layer]: return weight, bias dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls]) diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index eb84edd88..2c02b6f69 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -37,7 +37,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): setattr(obj, attrs[-1], value) -def getattr_(obj, attr: str, ignore: bool = None): +def getattr_(obj, attr: str, ignore: bool = False): r""" Get the object's multi sublevel attr