diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 394334558..c5e33fd49 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist + from colossalai.core import global_context as gpc try: @@ -72,6 +73,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): total_input = input grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 406173a18..0ee3b4fcb 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -469,7 +469,8 @@ class Linear1D_Col(ParallelLayer): if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) + # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) + self.out_features_per_partition = out_features # Parameters. # Initialize weight. @@ -612,7 +613,8 @@ class Linear1D_Row(ParallelLayer): raise ValueError('cannot skip bias addition if bias is None') # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) + self.input_size_per_partition = in_features # Parameters. # Initialize weight. @@ -884,7 +886,8 @@ class VocabParallelEmbedding1D(ParallelLayer): tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings_per_partition = num_embeddings self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md new file mode 100644 index 000000000..a47e280f2 --- /dev/null +++ b/colossalai/shardformer/README.md @@ -0,0 +1,177 @@ +## ShardFormer + +### Intro +Make the model in huggingface.co can be paralleled and can be used with colossalai according to custom policy. + +### Quick start +1. Usage +- Use +``` python +from colossalai.shardformer.shard.shardmodel import ShardModel +from transformers import BertForMaskedLM + +# create huggingface model as normal +model = BertForMaskedLM.from_pretrained("bert-base-uncased") + +# make the huggingface model paralleled to ShardModel +# auto policy: +shardmodel = ShardModel(model).model + +# custom policy: +from xxx import +shardmodel = ShardModel(model, ).model + + +# do angthing as normal +... +``` +- Policy + +If you wanna parallel the model in custom way, just overwrite the policy class for the huggingface model. + +You should do: + +1. Inherit Policy class +2. Overwrite argument_policy method + - In this method you need to list which layers class you wanna modify and the attributes and parameters in those layers. +3. Overwrite inject_policy method [Optional] + - If you need to modify the forward or backward progress. +4. Overwrite or add the param recording functions + - These function use suffix to record the path of weight or bias for the layer. +5. Overwrite binding + +More details can be found in shardformer/policies/basepolicy.py +``` python +from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument + +CustomPolicy(Policy): + @staticmethod + def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]: + """ + Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions + + Args: + model_config: The config of transformer model + shard_setting: The config of distributed model + + Return: + Dict for the modify policy, + { + origin layer class1 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), + origin layer class2 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), + ... + } + + """ + raise NotImplementedError + + @staticmethod + def inject_policy() -> Tuple[nn.Module, nn.Module]: + """ + Return the dict for the inject model + + Return: + The injected model, key is the original model and value is the new shardmodel + """ + return () + + @staticmethod + def binding_policy() -> Dict: + """ + Return the dict for the binding model + """ + return NotImplementedError + + @staticmethod + def attn_in() -> List: + """ + Attention qkv layer + + Returns: + List[Layer]: List of layer object, each layer is the new + """ + return NotImplementedError + + @staticmethod + def attn_out() -> List: + """ + Attention output projection layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def mlp_in() -> List: + """ + h -> 4h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def mlp_out() -> List: + """ + 4h -> h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def embedding() -> List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def unembedding() -> List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError + +``` + +2. Simple example +``` shell +# inference +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference +# train +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train +``` diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py index 87ed8ac30..6741ae866 100644 --- a/colossalai/shardformer/model/modeling_bert.py +++ b/colossalai/shardformer/model/modeling_bert.py @@ -1,12 +1,14 @@ +from typing import Any, Dict, List, Type + import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from typing import Any, Dict, List, Type - - from transformers import BertForMaskedLM from transformers.models.bert.modeling_bert import MaskedLMOutput + + class BertForMaskedLM_(BertForMaskedLM): + def forward( self, input_ids=None, @@ -23,7 +25,7 @@ class BertForMaskedLM_(BertForMaskedLM): return_dict=None, **kwargs, ): - print("[Inject OK] Injected forward method") + # print("[Inject OK] Injected forward method") return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( @@ -46,9 +48,9 @@ class BertForMaskedLM_(BertForMaskedLM): masked_lm_loss = None # if input_ids is not None: - # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) + # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -60,4 +62,4 @@ class BertForMaskedLM_(BertForMaskedLM): logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 9142e0dae..e096c2b13 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -1,40 +1,47 @@ import torch.nn as nn + def build_policies(): - """ + r""" Build the policies for the model - + Return: The dict for the policies """ auto_policy_dict = {} from transformers.models.bert.modeling_bert import BertForMaskedLM + from .bert import BertForMaskedLMPolicy auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy from transformers.models.bert.modeling_bert import BertForSequenceClassification + from .bert import BertForSequenceClassificationPolicy auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy - + return auto_policy_dict -def get_autopolicy(model:nn.Module): - """ + +def get_autopolicy(model: nn.Module): + r""" Return the auto policy for the model Args: - model: The model to be used + model (:class:`nn.Module`): The model to get the auto policy Return: - The auto policy for the model + :class:`Policy`: The auto policy for the model """ auto_policy_dict = build_policies() policy = auto_policy_dict.get(model.__class__, None) - if policy is None: - raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}") + if policy is None: + raise NotImplementedError( + f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}" + ) return policy + # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining # model = BertForPreTraining # policy = get_autopolicy(model) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index d444aeb53..a5cc0bc68 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,28 +1,38 @@ # part of code modified from https://github.com/tunib-ai/parallelformers +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Tuple, Type + import torch import torch.nn as nn -import colossalai.nn as col_nn -from typing import Any, Dict, List, Type, Tuple, Callable from transformers import AutoConfig -from dataclasses import dataclass, field + +import colossalai.nn as col_nn + @dataclass class Argument: - attr_dict : Dict[str, Any] - param_funcs : List[Callable] - binding_layers : List[nn.Module] = field(default_factory=list) + r""" + The argument class for the policy + + Args: + attr_dict (Dict[str, Any]): The dict for the param setting + param_funcs (:class:`List[Callable]`): The list for the param functions + """ + attr_dict: Dict[str, Any] + param_funcs: List[Callable] + @dataclass class Layer: - """ + r""" The layer object for the policy Args: - weight: The weight name of the layer - bias: The bias name of the layer - replace_layer: The layer to replace the original layer - ignore: Whether to ignore this layer if it is not in the model + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer + replace_layer (:class:`colosalai.nn`): The layer to replace the original layer + ignore (bool): Whether to ignore this layer if it is not in the model """ weight: str = None bias: str = None @@ -32,45 +42,55 @@ class Layer: @dataclass class Col_Layer(Layer): - """ + r""" Class for col shard layer in MegatronLM + + Args: + gather_output (bool): Whether to gather the output of the layer """ gather_output: bool = False @dataclass class Row_Layer(Layer): - """ + r""" Class for col shard layer in MegatronLM """ pass class Policy(): - """ + r""" The base class for all the policies - For each different model, it should have a different policy class, like BertPolicy for Bert Model - or OPTPolicy for OPT model. + For each different model, it should have a different policy class, like BertPolicy for Bert Model + or OPTPolicy for OPT model. AutoPolicy: - shardformer already defined some policies for huggingface model, just set custom_policy = None + Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None to use the auto policy. In shardformer autopolicy, we define a base policy for one type model, - like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM, + like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM, BertForSequenceClassification, etc., for each different Bert model we difine different policy class - and overwrite the method inject_policy - + and overwrite the method like ``inject_policy`` to modify the forward and backward process. + CustomPolicy: + If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite + all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy`` + class for the example. + """ + @staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]: - """ - Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions + def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: + r""" + Return the dict for the modify policy, the key is the original layer class and the value is the + argument for the modify layer Args: - model_config: The config of transformer model - shard_setting: The config of distributed model - + model_config (:class:`tansformer.Config`): The config of transformer model + shard_config (:class:`ShardConfig`): The config for sharding model + Return: Dict for the modify policy, + :: { origin layer class1 (nn.Module): Argument( attr_dict = { @@ -101,33 +121,51 @@ class Policy(): """ raise NotImplementedError - @staticmethod def inject_policy() -> Tuple[nn.Module, nn.Module]: - """ - Return the dict for the inject model + r""" + Return the dict for the inject model Return: The injected model, key is the original model and value is the new shardmodel + :: + (OrignModel, CustomModel) + in `CustomModel`, we can overwrite the forward and backward process """ return () - @staticmethod - def attn_in() -> List: + def binding_policy() -> Dict: + r""" + Return the dict for the binding model + + Return: + This method should return the binding relationship for some layers share the weight or bias, + the key and value is the suffix of the weight or bias of the model + :: + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } """ + return NotImplementedError + + @staticmethod + def attn_in() -> List: + r""" Attention qkv layer + In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be + ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters + in ``Layer`` object can refer to the ``Layer`` class. Returns: - List[Layer]: List of layer object, each layer is the new + List[Layer]: List of layer object, each layer is the new """ return NotImplementedError - @staticmethod def attn_out() -> List: - """ + r""" Attention output projection layer Returns: @@ -135,46 +173,40 @@ class Policy(): """ return NotImplementedError - @staticmethod def mlp_in() -> List: - """ + r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ return NotImplementedError - @staticmethod def mlp_out() -> List: - """ + r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ return NotImplementedError - - + @staticmethod - def embedding()->List: - """ + def embedding() -> List: + r""" Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums Return: List[Layer]: List of layer object """ return NotImplementedError - - + @staticmethod - def unembedding()->List: - """ + def unembedding() -> List: + r""" Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums Return: List[Layer]: List of layer object diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 24b95e827..5d91d8ddc 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,56 +1,57 @@ -from typing import Dict, List, Tuple, Type, Any, Callable +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple, Type + import torch.nn as nn -from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer +from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead + import colossalai.nn as col_nn -from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead -from dataclasses import dataclass + +from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer class BertPolicy(Policy): + @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: + def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: return { - BertLayer: Argument( - attr_dict = { - # 1. shard hidden size - "attention.self.all_head_size": config.hidden_size // world_size, - "crossattention.self.all_head_size": config.hidden_size // world_size, - # 2. shard number of heads - "attention.self.num_attention_heads": config.num_attention_heads // world_size, - "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, - - }, - param_funcs = [ - BertPolicy.attn_in, - BertPolicy.attn_out, - BertPolicy.mlp_in, - BertPolicy.mlp_out - ] - ), - BertEmbeddings: Argument( - attr_dict = { - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, - }, - param_funcs = [ - BertPolicy.embedding, - ], - binding_layers = [ - BertLMPredictionHead, - ] - ), - BertLMPredictionHead: Argument( - attr_dict = { - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - }, - param_funcs = [ - BertPolicy.unembedding, - ] - ) + BertLayer: + Argument( + attr_dict={ + # 1. shard hidden size + "attention.self.all_head_size": config.hidden_size // world_size, + "crossattention.self.all_head_size": config.hidden_size // world_size, + # 2. shard number of heads + "attention.self.num_attention_heads": config.num_attention_heads // world_size, + "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, + }, + param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]), + BertEmbeddings: + Argument( + attr_dict={ + # 1. shard vocab size + # "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. add the size of the sliced embedding layer excluding the last slice + "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, + }, + param_funcs=[ + BertPolicy.embedding, + ]), + BertLMPredictionHead: + Argument( + attr_dict={ + # 1. shard vocab size + # "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. add the size of the sliced embedding layer excluding the last slice + }, + param_funcs=[ + BertPolicy.unembedding, + ]) + } + + @staticmethod + def binding_policy() -> Dict: + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } @staticmethod @@ -89,9 +90,8 @@ class BertPolicy(Policy): replace_layer=col_nn.Linear1D_Col, ignore=True, ), - ] - + @staticmethod def attn_out() -> List: return [ @@ -107,17 +107,17 @@ class BertPolicy(Policy): ignore=True, ), ] - + @staticmethod def mlp_in() -> List: return [ - Col_Layer( + Col_Layer( weight="intermediate.dense.weight", bias="intermediate.dense.bias", replace_layer=col_nn.Linear1D_Col, ), ] - + @staticmethod def mlp_out() -> List: return [ @@ -130,13 +130,11 @@ class BertPolicy(Policy): @staticmethod def embedding() -> List: - return [ - Col_Layer( - weight="word_embeddings.weight", - replace_layer=col_nn.VocabParallelEmbedding1D, - ) - ] - + return [Col_Layer( + weight="word_embeddings.weight", + replace_layer=col_nn.VocabParallelEmbedding1D, + )] + @staticmethod def unembedding() -> List: return [ @@ -148,16 +146,21 @@ class BertPolicy(Policy): ) ] + from transformers import BertForMaskedLM + from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ + + class BertForMaskedLMPolicy(BertPolicy): + @staticmethod def inject_policy() -> Tuple[nn.Module, nn.Module]: return (BertForMaskedLM, BertForMaskedLM_) - - + class BertForSequenceClassificationPolicy(BertPolicy): + @staticmethod def inject_policy() -> Dict: return {} @@ -165,4 +168,4 @@ class BertForSequenceClassificationPolicy(BertPolicy): # model = BertForMaskedLM.from_pretrained("bert-base-uncased") # _ = BertForMaskedLMPolicy(model) -# print(isinstance(model,list(_.inject_policy().keys())[0])) \ No newline at end of file +# print(isinstance(model,list(_.inject_policy().keys())[0])) diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shardconfig.py index be265ff0c..c6a2513a6 100644 --- a/colossalai/shardformer/shard/shardconfig.py +++ b/colossalai/shardformer/shard/shardconfig.py @@ -10,9 +10,9 @@ class ShardConfig: fp16: bool = True num_gpus: int = 2 world_size: int = 2 - backend="nccl" + backend = "nccl" verbose: str = 'simple' seed: int = None require_grad: bool = False master_addr: str = "127.0.0.1" - master_port: int = 29500 \ No newline at end of file + master_port: int = 29500 diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ef785cfee..2f6bb4265 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,56 +1,59 @@ +import os +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union + import torch import torch.nn as nn -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable -from .shardconfig import ShardConfig -from dataclasses import dataclass -from ..policies.basepolicy import Policy, Layer -from ..policies.autopolicy import get_autopolicy -from .slicer import Slicer -from ..utils.utils import hasattr_, setattr_, getattr_ + import colossalai.nn as col_nn from colossalai.logging import get_dist_logger -import os +from ..policies.autopolicy import get_autopolicy +from ..policies.basepolicy import Layer, Policy +from ..utils.utils import getattr_, hasattr_, setattr_ +from .shardconfig import ShardConfig +from .slicer import Slicer logger = get_dist_logger() + class ModelSharder(object): - """ + r""" Shard the original huggingface model according to the policy Args: - policy: The policy to shard the model - model: The model to shard - dist_setting: The setting of distributed model + policy (:class:`Policy`): The policy to shard the model + model (:class:`torch.Module`): The model to shard + shard_config: The setting of distributed model """ + def __init__( self, model: nn.Module, policy: Policy, - shard_config: ShardConfig = None, # TODO - ) -> None: + shard_config: ShardConfig = None, # TODO + ) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy self.slicer = Slicer(shard_config) self.shard_config = shard_config self.model_config = self.model.config - self.binding_map = {} - def shard(self) -> None: self.inject_model(self.model) self.replace_layer(self.model) - - + self.bind_layer(self.model) + def inject_model( - self, - model: nn.Module, - ) -> None: - """ + self, + model: nn.Module, + ) -> None: + r""" Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model - + e.g. + :: BertForMaskedLM.forward -> BertForMaskedLM_.forward """ inject_policy = self.policy.inject_policy() @@ -64,49 +67,43 @@ class ModelSharder(object): setattr( model.__class__, key, - getattr(shard_model_cls,key), + getattr(shard_model_cls, key), ) else: raise NotImplementedError(f"{model.__class__} is not implemented so far") - def replace_layer( - self, - model: nn.Module, - ) -> None: - """ + self, + model: nn.Module, + ) -> None: + r""" Replace the layer according to the policy, and replace the layer one by one Args: - layer: The layer to shard + model (:class:`torch.nn.Module`): The layer to shard """ argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) for argument_policy in argument_policies.items(): origin_layer_cls = argument_policy[0] attr_dict = argument_policy[1].attr_dict param_funcs = argument_policy[1].param_funcs - binding_layers = argument_policy[1].binding_layers - # if binding_layer is not None: - # self.binding_map[origin_layer_cls] = binding_layer - self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers) - + self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) def reverse_replace_layer( - self, - layer: nn.Module, - origin_cls: nn.Module, - attr_dict: Dict[str, Any], - param_funcs: List[Callable], - binding_layers: List[nn.Module] - ) -> None: - """ + self, + layer: nn.Module, + origin_cls: nn.Module, + attr_dict: Dict[str, Any], + param_funcs: List[Callable], + ) -> None: + r""" Reverse the replace layer operation Args: - layer: The object of layer to shard - origin_cls: The origin layer class - attr_dict: The attribute dict to modify - policy_cls: The policy class + layer (:class:`torch.nn.Module`): The object of layer to shard + origin_cls (:class:`transformers.model`): The origin layer class + attr_dict (Dict): The attribute dict to modify + policy_cls (:class:`Policy`): The policy class """ for name, child in layer.named_children(): if child.__class__ == origin_cls: @@ -115,25 +112,23 @@ class ModelSharder(object): setattr_(child, k, v, ignore=True) # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) - self.shard_one_layer(child, param_funcs, binding_layers) + self.shard_one_layer(child, param_funcs) continue - self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers) + self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs) return layer - def shard_one_layer( - self, - org_layer: nn.Module, - param_funcs: List[Callable], - binding_layers: List[nn.Module] - ) -> None: - """ + self, + org_layer: nn.Module, + param_funcs: List[Callable], + ) -> None: + r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: - org_layer: The origin layer object to shard - param_funcs: The function list to get shard information in policy class + org_layer (:class:`torch.nn.Module`): The origin layer object to shard + param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class """ # print(org_layer) @@ -148,7 +143,7 @@ class ModelSharder(object): ignore = policy_layer.ignore if policy_layer.__class__.__name__ == "Col_Layer": gather_output = policy_layer.gather_output - print(gather_output) + # print(gather_output) if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -172,67 +167,81 @@ class ModelSharder(object): # slice weight and bias weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) - print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) - # save the binding information - for binding_layer in binding_layers: - self.binding_map[binding_layer] = dict(weight=weight, bias=bias) + # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) # create new object to replace the origin layer if replace_layer_cls is not None: # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}") if isinstance(getattr_(org_layer, layer_attr), nn.Linear): if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True) + replace_layer = replace_layer_cls(weight.shape[1], + weight.shape[0], + bias=False if bias is None else True) elif replace_layer_cls.__name__ == "Linear1D_Col": - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output) + replace_layer = replace_layer_cls(weight.shape[0], + weight.shape[1], + bias=False if bias is None else True, + gather_output=gather_output) setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) + elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], + getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) self.set_param(replace_layer, weight, bias) else: - raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") + raise NotImplementedError( + f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") # do not replace the layer object, just replace the weight and bias else: self.set_param(org_layer, layer_attr, weight, bias) - - def set_param( - self, - layer: Any, - layer_attr: str = "", - weight: torch.Tensor = None, - bias: torch.Tensor = None - ) -> None: - """ + def set_param(self, + layer: Any, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + layer_attr: str = "") -> None: + r""" Reset the weight and bias of the layer object Args: - layer: The layer object - layer_attr: The attribute name of the layer - weight: The weight of the layer - bias: The bias of the layer + layer (:class:`torch.nn.Module`): The layer object + layer_attr (str): The attribute name of the layer + weight (:class:`torch.Tensor`): The weight of the layer + bias (:class:`torch.Tensor`): The bias of the layer """ assert weight is not None or bias is not None if weight is not None: - setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight)) + setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous())) self.set_layer_size(layer, layer_attr, weight.shape) if bias is not None: - setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias)) - + setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous())) def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: - """ + r""" Set the layer attribute Args: - layer: The layer object - layer_attr: The attribute name of the layer - size: Torch.size + layer (:class:`torch.nn.Module`): The layer object + layer_attr (str): The attribute name of the layer + size (:class:`torch.Size`): The size of the tensor """ # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features attrs = ["out_features", "in_features"] for i, attr in enumerate(attrs): if hasattr_(layer, f"{layer_attr}.{attr}"): - setattr_(layer, f"{layer_attr}.{attr}", size[i]) + setattr_(layer, f"{layer_attr}.{attr}", size[i]) + + def bind_layer(self, model: nn.Module) -> None: + r""" + Bind the layer according to the binding policy + + Args: + model (:class:`torch.nn.Module`): The shard model + """ + binding_map = self.policy.binding_policy() + for k, v in binding_map.items(): + param = getattr_(model, k) + param = nn.Parameter(param) + setattr_(model, k, param) + setattr_(model, v, param) diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py index 54d7b5ba0..7e7d1576a 100644 --- a/colossalai/shardformer/shard/shardmodel.py +++ b/colossalai/shardformer/shard/shardmodel.py @@ -1,46 +1,48 @@ import os +from contextlib import suppress +from dataclasses import dataclass + import torch +import torch.distributed as dist import torch.nn as nn import transformers -import torch.distributed as dist -from dataclasses import dataclass -from contextlib import suppress from colossalai.tensor.d_tensor.layout import Layout + from ..policies.basepolicy import Policy -from .sharder import ModelSharder from .shardconfig import ShardConfig +from .sharder import ModelSharder class ShardModel(object): - """ - The class for sharding the huggingface model, self.model is the sharded model + r""" + The class for sharding the huggingface model, ''self.model'' is the sharded model Just creat a new ShardModel object to shard huggingface model Args: - model: the origin huggingface model - dist_config: the config for distribute information - custom_policy: the custom policy for sharding + model (:class:`torch.nn.Model`): the origin huggingface model + dist_config (:class:`ShardConfig`): the config for distribute information + custom_policy (:class:`Policy`): the custom policy for sharding """ + def __init__( - self, - model: nn.Module, - shard_config: ShardConfig = None, # TODO - custom_policy: Policy = None, - ) -> None: + self, + model: nn.Module, + shard_config: ShardConfig = None, # TODO + custom_policy: Policy = None, + ) -> None: self.model = model self.shard_config = shard_config self.policy = custom_policy # self.layout=, # TODO - sharder=ModelSharder( + sharder = ModelSharder( model=self.model, policy=self.policy, shard_config=self.shard_config, ) sharder.shard() - def set_environ(self) -> None: os.environ["TOKENIZERS_PARALLELISM"] = "true" os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU" @@ -55,4 +57,4 @@ class ShardModel(object): torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0"))) def back_to_org() -> None: - pass \ No newline at end of file + pass diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 1849cdc99..096f5db95 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,40 +1,40 @@ import os -from typing import Dict, Tuple from dataclasses import dataclass +from typing import Dict, Tuple import torch import torch.distributed as dist -from ..policies.basepolicy import Layer, Col_Layer, Row_Layer -from .shardconfig import ShardConfig +from ..policies.basepolicy import Col_Layer, Layer, Row_Layer +from .shardconfig import ShardConfig dim_mapping = {Col_Layer: 1, Row_Layer: 0} + class Slicer(): def __init__( - self, - shardconfig: ShardConfig #TODO + self, + shardconfig: ShardConfig #TODO ) -> None: self.shardconfig = shardconfig - def slice_weight_bias( self, weight: torch.Tensor, bias: torch.Tensor, policy_layer_cls: Layer, ): - """ + r""" Slice the weight and bias according to policy layer cls - Layer -> do nothing - Col_Layer -> slice the weight and bias along dim 1 - Row_Layer -> slice the weight along dim 0 and do not slice bias + ``Layer`` -> do nothing + ``Col_Layer`` -> slice the weight and bias along dim 1 + ``Row_Layer`` -> slice the weight along dim 0 and do not slice bias Args: - weight: The weight of the layer - bias: The bias of the layer - policy_layer_class: The class represent how to slice the tensor + weight (:class:`torch.nn.Module`): The weight of the layer + bias: (:class:`torch.nn.Module`): The bias of the layer + policy_layer_class (:class:`Policy`): The class represent how to slice the tensor """ if policy_layer_cls == Layer: return weight, bias @@ -46,42 +46,6 @@ class Slicer(): else: raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") return weight, bias - - - def slice_weight( - self, - weight: torch.Tensor, - policy_layer_cls: Layer, - ) -> torch.Tensor: - """ - Slice the weight and bias according to the shardconfig - - Args: - weight: The weight of the layer - bias: The bias of the layer - policy_layer_class: The class represent how to slice the tensor - """ - if weight is not None: - dim = dim_mapping[policy_layer_cls] - weight = self.slice_tensor(weight, dim, False) - return weight - - - def slice_bias( - self, - bias: torch.Tensor, - ) -> torch.Tensor: - """ - Slice the bias according to the shardconfig - - Args: - bias: The bias of the layer - """ - assert bias is not None, "The bias is None" - if bias is not None: - bias = self.slice_tensor(bias, 1, True) - return bias - def slice_tensor( self, @@ -89,8 +53,13 @@ class Slicer(): dim: int, is_bias: bool, ) -> torch.Tensor: - """ + r""" Slice tensor according to the config + + Args: + tensor_in (:class:`torch.Tensor`): The tensor to slice + dim (int): The dimension to slice + is_bias (bool): Whether the tensor is bias """ if tensor_in is None: return None @@ -99,69 +68,75 @@ class Slicer(): else: return self.slice_1d(tensor_in) - def slice_2d( self, tensor: torch.Tensor, dim: int, ) -> torch.Tensor: - """ - Slice the 2D tensor + r""" + Slice the 2D tensor Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + dim (int): The dimension to slice """ - assert dim in [0,1], f"Only support 2D tensor, but got {dim}D tensor" + assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor" if dim == 0: return self.slice_row(tensor) elif dim == 1: return self.slice_col(tensor) - def slice_1d( self, tensor: torch.Tensor, - dim: int = None, ) -> torch.Tensor: - """ - Slice the 1D tensor + r""" + Slice the 1D tensor Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + + Returns: + :class:`torch.Tensor`: The sliced tensor """ delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta - return tensor[down_idx:up_idx] + return tensor[down_idx:up_idx].contiguous() def slice_col( self, tensor: torch.Tensor, ) -> torch.Tensor: - """ + r""" Slice the tensor in column Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + + Returns: + :class:`torch.Tensor`: The sliced tensor + """ delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta - return tensor[down_idx:up_idx,:] - + return tensor[down_idx:up_idx, :].contiguous() def slice_row( self, tensor: torch.Tensor, ) -> torch.Tensor: - """ + r""" Slice the tensor in column Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + + Returns: + :class:`torch.Tensor`: The sliced tensor """ delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta - return tensor[:,down_idx:up_idx] - \ No newline at end of file + return tensor[:, down_idx:up_idx].contiguous() diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py index 295529429..2b80d8b3c 100644 --- a/colossalai/shardformer/test/config.py +++ b/colossalai/shardformer/test/config.py @@ -1,5 +1 @@ -parallel = dict( - data=1, - pipeline=1, - tensor=dict(size=2, mode='1d') -) \ No newline at end of file +parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')) diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index c2a9053ca..0cdc6ef38 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -1,23 +1,51 @@ -from transformers import AutoTokenizer -from transformers import BertForMaskedLM +import argparse +import inspect +import os + +import torch +import torch.nn as nn +from datasets import load_dataset +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments + import colossalai -from colossalai.shardformer.shard.shardmodel import ShardModel -from colossalai.utils import get_current_device, print_rank_0 from colossalai.logging import get_dist_logger from colossalai.shardformer.shard.shardconfig import ShardConfig -import inspect -import argparse -import torch.nn as nn -import os +from colossalai.shardformer.shard.shardmodel import ShardModel +from colossalai.utils import get_current_device, print_rank_0 +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + def get_args(): parser = colossalai.get_default_parser() + parser.add_argument("--mode", type=str, default='inference') return parser.parse_args() + +def load_data(): + datasets = load_dataset('wikitext', 'wikitext-2-raw-v1') + # datasets=load_dataset("yelp_review_full") + tokenized_datasets = datasets.map( + lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True) + tokenized_datasets = tokenized_datasets.remove_columns(["text"]) + # tokenized_datasets=tokenized_datasets.rename_column("label","labels") + tokenized_datasets.set_format("torch") + + train_dataset = tokenized_datasets["train"].select(range(500)) + test_dataset = tokenized_datasets["test"].select(range(100)) + + datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") + train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) + eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) + return train_dataloader, eval_dataloader + + def inference(model: nn.Module): - # print(model) + print(model) + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") token = "Hello, my dog is cute" inputs = tokenizer(token, return_tensors="pt") inputs.to("cuda") @@ -25,13 +53,48 @@ def inference(model: nn.Module): outputs = model(**inputs) print(outputs) + +def train(model: nn.Module, num_epoch: int = 2): + train_dataloader, eval_dataloader = load_data() + optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) + progress_bar = tqdm(range((num_epoch) * len(train_dataloader))) + criterion = nn.CrossEntropyLoss() + model.to("cuda") + model.train() + for epoch in range(num_epoch): + progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}") + + for batch in train_dataloader: + optimizer.zero_grad() + batch = {k: v.to('cuda') for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + progress_bar.update(1) + train_loss = loss + + loss = 0.0 + for batch in eval_dataloader: + batch = {k: v.to('cuda') for k, v in batch.items()} + outputs = model(**batch) + # loss = outputs.loss + loss += outputs.loss.item() + # loss = criterion(outputs.logits, batch["input_ids"]) + test_loss = loss / len(eval_dataloader) + print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") + + if __name__ == "__main__": args = get_args() colossalai.launch_from_torch(config=args.config) model = BertForMaskedLM.from_pretrained("bert-base-uncased") shard_config = ShardConfig( - rank = int(str(get_current_device()).split(':')[-1]), - world_size= int(os.environ['WORLD_SIZE']), + rank=int(str(get_current_device()).split(':')[-1]), + world_size=int(os.environ['WORLD_SIZE']), ) shardmodel = ShardModel(model, shard_config) - inference(shardmodel.model) + if args.mode == "train": + train(shardmodel.model) + elif args.mode == "inference": + inference(shardmodel.model) diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index 5eba87f6f..eb84edd88 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -1,10 +1,10 @@ def hasattr_(obj, attr: str): - """ + r""" Check whether the object has the multi sublevel attr Args: - obj: The object to check - attr: The multi level attr to check + obj (object): The object to check + attr (str): The multi level attr to check """ attrs = attr.split('.') for a in attrs: @@ -14,15 +14,16 @@ def hasattr_(obj, attr: str): return False return True -def setattr_(obj, attr: str, value, ignore: bool=False): - """ + +def setattr_(obj, attr: str, value, ignore: bool = False): + r""" Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist Args: - obj: The object to set - attr: The multi level attr to set - value: The value to set - ignore: Whether to ignore when the attr doesn't exist + obj (object): The object to set + attr (str): The multi level attr to set + value (Any): The value to set + ignore (bool): Whether to ignore when the attr doesn't exist """ attrs = attr.split('.') @@ -31,18 +32,19 @@ def setattr_(obj, attr: str, value, ignore: bool=False): obj = getattr(obj, a) except AttributeError: if ignore: - return + return raise AttributeError(f"Object {obj} has no attribute {attr}") setattr(obj, attrs[-1], value) -def getattr_(obj, attr: str, ignore: bool=None): - """ + +def getattr_(obj, attr: str, ignore: bool = None): + r""" Get the object's multi sublevel attr - + Args: - obj: The object to set - attr: The multi level attr to set - ignore: Whether to ignore when the attr doesn't exist + obj (object): The object to set + attr (str): The multi level attr to set + ignore (bool): Whether to ignore when the attr doesn't exist """ attrs = attr.split('.') @@ -53,4 +55,4 @@ def getattr_(obj, attr: str, ignore: bool=None): if ignore: return None raise AttributeError(f"Object {obj} has no attribute {attr}") - return obj \ No newline at end of file + return obj