From 8d68de767d156012924cdbcba318ac8d85bd72a7 Mon Sep 17 00:00:00 2001
From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Date: Mon, 22 May 2023 15:02:17 +0800
Subject: [PATCH] [shardformer] init shardformer code structure (#3731)

* init shardformer code structure

* add implementation of the sharder (inject and replace)

* add implementation of replacing layers with Colossal-AI layers

* separate policies for different layers, add some notes

* implement 1D and 2D slicers that distinguish column and row sharding

* fix bugs when slicing and injecting the model

* fix some bugs; add an inference test example

---
 colossalai/shardformer/__init__.py            |   0
 colossalai/shardformer/model/__init__.py      |   0
 colossalai/shardformer/model/modeling_bert.py |  63 +++++
 colossalai/shardformer/policies/__init__.py   |   0
 colossalai/shardformer/policies/autopolicy.py |  41 +++
 colossalai/shardformer/policies/basepolicy.py | 182 ++++++++++++++
 colossalai/shardformer/policies/bert.py       | 168 +++++++++++++
 colossalai/shardformer/shard/__init__.py      |   0
 colossalai/shardformer/shard/shardconfig.py   |  18 ++
 colossalai/shardformer/shard/sharder.py       | 238 ++++++++++++++++++
 colossalai/shardformer/shard/shardmodel.py    |  58 +++++
 colossalai/shardformer/shard/slicer.py        | 167 ++++++++++++
 colossalai/shardformer/test/config.py         |   5 +
 colossalai/shardformer/test/test.py           |  37 +++
 colossalai/shardformer/utils/__init__.py      |   0
 colossalai/shardformer/utils/utils.py         |  56 +++++
 16 files changed, 1033 insertions(+)
 create mode 100644 colossalai/shardformer/__init__.py
 create mode 100644 colossalai/shardformer/model/__init__.py
 create mode 100644 colossalai/shardformer/model/modeling_bert.py
 create mode 100644 colossalai/shardformer/policies/__init__.py
 create mode 100644 colossalai/shardformer/policies/autopolicy.py
 create mode 100644 colossalai/shardformer/policies/basepolicy.py
 create mode 100644 colossalai/shardformer/policies/bert.py
 create mode 100644 colossalai/shardformer/shard/__init__.py
 create mode 100644 colossalai/shardformer/shard/shardconfig.py
 create mode 100644 colossalai/shardformer/shard/sharder.py
 create mode 100644 colossalai/shardformer/shard/shardmodel.py
 create mode 100644 colossalai/shardformer/shard/slicer.py
 create mode 100644 colossalai/shardformer/test/config.py
 create mode 100644 colossalai/shardformer/test/test.py
 create mode 100644 colossalai/shardformer/utils/__init__.py
 create mode 100644 colossalai/shardformer/utils/utils.py

diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/shardformer/model/__init__.py b/colossalai/shardformer/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py
new file mode 100644
index 000000000..87ed8ac30
--- /dev/null
+++ b/colossalai/shardformer/model/modeling_bert.py
@@ -0,0 +1,63 @@
+from torch.nn import CrossEntropyLoss
+
+from transformers import BertForMaskedLM
+from transformers.models.bert.modeling_bert import MaskedLMOutput
+
+
+class BertForMaskedLM_(BertForMaskedLM):
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        **kwargs,
+    ):
+        print("[Inject OK] Injected forward method")
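+        # NOTE: aside from the debug print above, this forward is intended to
+        # mirror the stock HuggingFace implementation; the commented-out
+        # applyDistCrossEntropy call below marks where a vocab-parallel loss
+        # (a hypothetical helper, not implemented in this patch) could plug in
+        # once the unembedding layer is sharded.
+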
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        masked_lm_loss = None
+
+        # if input_ids is not None:
+        #     masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()    # -100 index = padding token
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
\ No newline at end of file
diff --git a/colossalai/shardformer/policies/__init__.py b/colossalai/shardformer/policies/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py
new file mode 100644
index 000000000..9142e0dae
--- /dev/null
+++ b/colossalai/shardformer/policies/autopolicy.py
@@ -0,0 +1,41 @@
+import torch.nn as nn
+
+
+def build_policies():
+    """
+    Build the map from model classes to their policy classes
+
+    Return:
+        The dict mapping model classes to policy classes
+    """
+    auto_policy_dict = {}
+
+    from transformers.models.bert.modeling_bert import BertForMaskedLM
+    from .bert import BertForMaskedLMPolicy
+    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
+
+    from transformers.models.bert.modeling_bert import BertForSequenceClassification
+    from .bert import BertForSequenceClassificationPolicy
+    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
+
+    return auto_policy_dict
+
+
+def get_autopolicy(model: nn.Module):
+    """
+    Return the auto policy for the model
+
+    Args:
+        model: The model to get the policy for
+
+    Return:
+        The auto policy class for the model
+    """
+    auto_policy_dict = build_policies()
+    policy = auto_policy_dict.get(model.__class__, None)
+    if policy is None:
+        raise NotImplementedError(
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n"
+            f"Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}")
+    return policy
+
+# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
+# model = BertForPreTraining
+# policy = get_autopolicy(model)
+# print(policy)
diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
new file mode 100644
index 000000000..d444aeb53
--- /dev/null
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -0,0 +1,182 @@
+# part of code modified from https://github.com/tunib-ai/parallelformers
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Tuple, Callable
+
+import torch.nn as nn
+
+
+@dataclass
+class Argument:
+    """
+    The argument object for modifying one layer class
+
+    Args:
+        attr_dict: The attributes to overwrite on matched modules
+        param_funcs: The policy static methods describing the parameters to shard
+        binding_layers: Layer classes whose parameters are tied to this one
+    """
+    attr_dict: Dict[str, Any]
+    param_funcs: List[Callable]
+    binding_layers: List[nn.Module] = field(default_factory=list)
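+
+# A minimal sketch of how a concrete policy fills in an Argument (hypothetical
+# values; the real mappings live in the per-model policies such as BertPolicy):
+#
+#     Argument(
+#         attr_dict={"attention.self.num_attention_heads": heads // world_size},
+#         param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out],
+#         binding_layers=[BertLMPredictionHead],
+#     )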
+
+
+@dataclass
+class Layer:
+    """
+    The layer object for the policy
+
+    Args:
+        weight: The weight name of the layer
+        bias: The bias name of the layer
+        replace_layer: The layer to replace the original layer
+        ignore: Whether to ignore this layer if it is not in the model
+    """
+    weight: str = None
+    bias: str = None
+    replace_layer: Any = None
+    ignore: bool = False
+
+
+@dataclass
+class Col_Layer(Layer):
+    """
+    Class for the column-sharded layer in MegatronLM
+    """
+    gather_output: bool = False
+
+
+@dataclass
+class Row_Layer(Layer):
+    """
+    Class for the row-sharded layer in MegatronLM
+    """
+    pass
+
+
+class Policy():
+    """
+    The base class for all the policies.
+    Each model should have its own policy class, like BertPolicy for the Bert model
+    or OPTPolicy for the OPT model.
+
+    AutoPolicy:
+        Shardformer already defines some policies for huggingface models. Set
+        custom_policy = None to use the auto policy. In the shardformer autopolicy,
+        we define a base policy for one model family, like BertPolicy, and for each
+        huggingface Bert variant, like BertForMaskedLM, BertForSequenceClassification,
+        etc., we define a policy class that overrides the inject_policy method.
+
+    CustomPolicy:
+        Pass a user-defined Policy object as custom_policy to override the auto policy.
+    """
+
+    @staticmethod
+    def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
+        """
+        Return a dict mapping each layer class to be modified to an Argument object
+        holding the attribute settings and parameter functions
+
+        Args:
+            model_config: The config of the transformer model
+            world_size: The world size of the distributed setting
+
+        Return:
+            Dict for the modify policy,
+            {
+                origin layer class1 (nn.Module): Argument(
+                    attr_dict = {
+                        argument1: value1,
+                        argument2: value2,
+                        ...
+                    },
+                    param_funcs = [
+                        staticmethod1,
+                        staticmethod2,
+                        ...
+                    ]
+                ),
+                origin layer class2 (nn.Module): Argument(
+                    attr_dict = {
+                        argument1: value1,
+                        argument2: value2,
+                        ...
+                    },
+                    param_funcs = [
+                        staticmethod1,
+                        staticmethod2,
+                        ...
+                    ]
+                ),
+                ...
+            }
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def inject_policy() -> Tuple[nn.Module, nn.Module]:
+        """
+        Return the pair of classes for model injection
+
+        Return:
+            A tuple (original model class, sharded model class); the sharded
+            class's methods are injected into the original class
+        """
+        return ()
+
+    @staticmethod
+    def attn_in() -> List:
+        """
+        Attention qkv layer
+
+        Returns:
+            List[Layer]: List of layer objects, each describing how to shard one parameter
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def attn_out() -> List:
+        """
+        Attention output projection layer
+
+        Returns:
+            List[Layer]: List of layer objects
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def mlp_in() -> List:
+        """
+        h -> 4h mlp layer
+
+        Returns:
+            List[Layer]: List of layer objects
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def mlp_out() -> List:
+        """
+        4h -> h mlp layer
+
+        Returns:
+            List[Layer]: List of layer objects
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def embedding() -> List:
+        """
+        Partially slice the embedding layer
+        vocab_size -> vocab_size // gpu_nums
+
+        Return:
+            List[Layer]: List of layer objects
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def unembedding() -> List:
+        """
+        Partially slice the unembedding (LM head) layer
+        vocab_size -> vocab_size // gpu_nums
+
+        Return:
+            List[Layer]: List of layer objects
+        """
+        raise NotImplementedError
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
new file mode 100644
index 000000000..24b95e827
--- /dev/null
+++ b/colossalai/shardformer/policies/bert.py
@@ -0,0 +1,168 @@
+from typing import Dict, List, Tuple
+
+import torch.nn as nn
+import colossalai.nn as col_nn
+from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
+
+from .basepolicy import Policy, Argument, Col_Layer, Row_Layer
+
+
+class BertPolicy(Policy):
+
+    @staticmethod
+    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
+        return {
+            BertLayer: Argument(
+                attr_dict={
+                    # 1. shard hidden size
+                    "attention.self.all_head_size": config.hidden_size // world_size,
+                    "crossattention.self.all_head_size": config.hidden_size // world_size,
+                    # 2. shard number of heads
+                    "attention.self.num_attention_heads": config.num_attention_heads // world_size,
+                    "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
+                },
+                param_funcs=[
+                    BertPolicy.attn_in,
+                    BertPolicy.attn_out,
+                    BertPolicy.mlp_in,
+                    BertPolicy.mlp_out
+                ]),
+            BertEmbeddings: Argument(
+                attr_dict={
+                    # 1. shard vocab size
+                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                    # 2. add the size of the sliced embedding layer excluding the last slice
+                    "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
+                },
+                param_funcs=[
+                    BertPolicy.embedding,
+                ],
+                binding_layers=[
+                    BertLMPredictionHead,
+                ]),
+            BertLMPredictionHead: Argument(
+                attr_dict={
+                    # 1. shard vocab size
+                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                    # 2. add the size of the sliced embedding layer excluding the last slice
+                },
+                param_funcs=[
+                    BertPolicy.unembedding,
+                ]),
+        }
+
+    @staticmethod
+    def attn_in() -> List:
+        return [
+            Col_Layer(
+                weight="attention.self.query.weight",
+                bias="attention.self.query.bias",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Col_Layer(
+                weight="attention.self.key.weight",
+                bias="attention.self.key.bias",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Col_Layer(
+                weight="attention.self.value.weight",
+                bias="attention.self.value.bias",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Col_Layer(
+                weight="crossattention.self.query.weight",
+                bias="crossattention.self.query.bias",
+                replace_layer=col_nn.Linear1D_Col,
+                ignore=True,
+            ),
+            Col_Layer(
+                weight="crossattention.self.key.weight",
+                bias="crossattention.self.key.bias",
+                replace_layer=col_nn.Linear1D_Col,
+                ignore=True,
+            ),
+            Col_Layer(
+                weight="crossattention.self.value.weight",
+                bias="crossattention.self.value.bias",
+                replace_layer=col_nn.Linear1D_Col,
+                ignore=True,
+            ),
+        ]
+
+    @staticmethod
+    def attn_out() -> List:
+        return [
+            Row_Layer(
+                weight="attention.output.dense.weight",
+                bias="attention.output.dense.bias",
+                replace_layer=col_nn.Linear1D_Row,
+            ),
+            Row_Layer(
+                weight="crossattention.output.dense.weight",
+                bias="crossattention.output.dense.bias",
+                replace_layer=col_nn.Linear1D_Row,
+                ignore=True,
+            ),
+        ]
+
+    @staticmethod
+    def mlp_in() -> List:
+        return [
+            Col_Layer(
+                weight="intermediate.dense.weight",
+                bias="intermediate.dense.bias",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+        ]
+
+    @staticmethod
+    def mlp_out() -> List:
+        return [
+            Row_Layer(
+                weight="output.dense.weight",
+                bias="output.dense.bias",
+                replace_layer=col_nn.Linear1D_Row,
+            ),
+        ]
+
+    @staticmethod
+    def embedding() -> List:
+        return [
+            Col_Layer(
+                weight="word_embeddings.weight",
+                replace_layer=col_nn.VocabParallelEmbedding1D,
+            )
+        ]
+
+    @staticmethod
+    def unembedding() -> List:
+        return [
+            Col_Layer(
+                weight="decoder.weight",
+                bias="decoder.bias",
+                replace_layer=col_nn.Linear1D_Col,
+                gather_output=True,
+            )
+        ]
+
+
+from transformers import BertForMaskedLM
+
+from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
+
+
+class BertForMaskedLMPolicy(BertPolicy):
+
+    @staticmethod
+    def inject_policy() -> Tuple[nn.Module, nn.Module]:
+        return (BertForMaskedLM, BertForMaskedLM_)
+
+
+class BertForSequenceClassificationPolicy(BertPolicy):
+
+    @staticmethod
+    def inject_policy() -> Dict:
+        return {}
+
+# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+# _ = BertForMaskedLMPolicy(model)
+# print(isinstance(model, list(_.inject_policy().keys())[0]))
\ No newline at end of file
diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shardconfig.py
new file mode 100644
index 000000000..be265ff0c
--- /dev/null
+++ b/colossalai/shardformer/shard/shardconfig.py
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class ShardConfig:
+    """
+    The config for sharding the huggingface model for testing
+    """
+    rank: int
+    fp16: bool = True
+    num_gpus: int = 2
+    world_size: int = 2
+    backend: str = "nccl"
+    verbose: str = 'simple'
+    seed: Optional[int] = None
+    require_grad: bool = False
+    master_addr: str = "127.0.0.1"
+    master_port: int = 29500
\ No newline at end of file
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
new file mode 100644
index 000000000..ef785cfee
--- /dev/null
+++ b/colossalai/shardformer/shard/sharder.py
@@ -0,0 +1,238 @@
+import os
+from typing import Any, Dict, List, Callable
+
+import torch
+import torch.nn as nn
+
+from colossalai.logging import get_dist_logger
+
+from ..policies.autopolicy import get_autopolicy
+from ..policies.basepolicy import Policy, Col_Layer
+from ..utils.utils import hasattr_, setattr_, getattr_
+from .shardconfig import ShardConfig
+from .slicer import Slicer
+
+logger = get_dist_logger()
+
+
+class ModelSharder(object):
+    """
+    Shard the original huggingface model according to the policy
+
+    Args:
+        policy: The policy to shard the model with
+        model: The model to shard
+        shard_config: The config of the distributed setting
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        policy: Policy,
+        shard_config: ShardConfig = None,    # TODO
+    ) -> None:
+        self.model = model
+        self.policy = get_autopolicy(self.model) if policy is None else policy
+        self.slicer = Slicer(shard_config)
+        self.shard_config = shard_config
+        self.model_config = self.model.config
+        self.binding_map = {}
+
+    def shard(self) -> None:
+        self.inject_model(self.model)
+        self.replace_layer(self.model)
+
+    def inject_model(
+        self,
+        model: nn.Module,
+    ) -> None:
+        """
+        Replace the model's methods with the policy-defined ones,
+        mainly modifying forward and backward to fit the distributed model
+
+        e.g.
+            BertForMaskedLM.forward -> BertForMaskedLM_.forward
+        """
+        inject_policy = self.policy.inject_policy()
+
+        org_model_cls = inject_policy[0]
+        shard_model_cls = inject_policy[1]
+
+        if model.__class__ == org_model_cls:
+            for key in shard_model_cls.__dict__.keys():
+                if hasattr(model.__class__, key):
+                    setattr(
+                        model.__class__,
+                        key,
+                        getattr(shard_model_cls, key),
+                    )
+        else:
+            raise NotImplementedError(f"{model.__class__} is not implemented so far")
+
+    def replace_layer(
+        self,
+        model: nn.Module,
+    ) -> None:
+        """
+        Replace the layers according to the policy, one by one
+
+        Args:
+            model: The model to shard
+        """
+        argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
+        for origin_layer_cls, argument_policy in argument_policies.items():
+            attr_dict = argument_policy.attr_dict
+            param_funcs = argument_policy.param_funcs
+            binding_layers = argument_policy.binding_layers
+            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers)
+
+    def reverse_replace_layer(
+        self,
+        layer: nn.Module,
+        origin_cls: nn.Module,
+        attr_dict: Dict[str, Any],
+        param_funcs: List[Callable],
+        binding_layers: List[nn.Module],
+    ) -> None:
+        """
+        Replace the layers recursively (depth-first), sharding every child
+        module whose class matches origin_cls
+
+        Args:
+            layer: The layer object to traverse
+            origin_cls: The original layer class to match
+            attr_dict: The attributes to overwrite on matched modules
+            param_funcs: The policy functions describing the parameters to shard
+            binding_layers: Layer classes whose parameters are tied to the matched one
+        """
+        for name, child in layer.named_children():
+            if child.__class__ == origin_cls:
+                for k, v in attr_dict.items():
+                    setattr_(child, k, v, ignore=True)
+                self.shard_one_layer(child, param_funcs, binding_layers)
+                continue
+
+            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers)
+        return layer
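+
+    # Illustration (assuming a HF BERT model): with origin_cls=BertLayer the
+    # recursion above descends through model.bert.encoder.layer[i], shards each
+    # matching block via shard_one_layer, and keeps traversing everything else.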
+
+    def shard_one_layer(
+        self,
+        org_layer: nn.Module,
+        param_funcs: List[Callable],
+        binding_layers: List[nn.Module],
+    ) -> None:
+        """
+        Shard one layer according to the policy; the layer should be the same
+        class as a key in the dict returned by the policy's argument_policy
+
+        Args:
+            org_layer: The original layer object to shard
+            param_funcs: The list of policy functions describing the parameters to shard
+            binding_layers: Layer classes whose parameters are tied to this one
+        """
+        for func in param_funcs:
+            policy_layers = func()
+            for policy_layer in policy_layers:
+                weight = None
+                bias = None
+                weight_attr = policy_layer.weight
+                bias_attr = policy_layer.bias
+                replace_layer_cls = policy_layer.replace_layer
+                ignore = policy_layer.ignore
+                gather_output = False
+                if isinstance(policy_layer, Col_Layer):
+                    gather_output = policy_layer.gather_output
+
+                if weight_attr is not None:
+                    if hasattr_(org_layer, weight_attr):
+                        weight = getattr_(org_layer, weight_attr)
+                    elif not ignore:
+                        raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
+
+                if bias_attr is not None:
+                    if hasattr_(org_layer, bias_attr):
+                        bias = getattr_(org_layer, bias_attr)
+                    elif not ignore:
+                        raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
+
+                # the attribute is not in the model and ignore is set, skip it
+                if weight is None and bias is None and ignore:
+                    continue
+
+                # set the sliced weight and bias to the new nn_col layer
+                assert weight is not None or bias is not None
+                layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
+
+                # slice weight and bias
+                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
+                logger.debug(f"RANK {os.environ['RANK']}: {policy_layer.__class__} {weight.shape} "
+                             f"{bias.shape if bias is not None else None}")
+
+                # save the binding information
+                for binding_layer in binding_layers:
+                    self.binding_map[binding_layer] = dict(weight=weight, bias=bias)
+
+                # create a new object to replace the original layer
+                if replace_layer_cls is not None:
+                    if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
+                        if replace_layer_cls.__name__ == "Linear1D_Row":
+                            replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0],
+                                                              bias=False if bias is None else True)
+                        elif replace_layer_cls.__name__ == "Linear1D_Col":
+                            replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
+                                                              bias=False if bias is None else True,
+                                                              gather_output=gather_output)
+                        setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
+                        self.set_param(replace_layer, weight=weight, bias=bias)
+                    elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
+                        replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
+                                                          getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
+                        setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
+                        self.set_param(replace_layer, weight=weight, bias=bias)
+                    else:
+                        raise NotImplementedError(
+                            f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
+                # do not replace the layer object, just replace the weight and bias
+                else:
+                    self.set_param(org_layer, layer_attr, weight, bias)
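+
+    # Note: set_param below is called either as set_param(replace_layer, weight=w, bias=b)
+    # for a freshly built Colossal-AI layer, or as set_param(org_layer, layer_attr, w, b)
+    # to overwrite the tensors of an existing sublayer in place.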
+
+    def set_param(
+        self,
+        layer: Any,
+        layer_attr: str = "",
+        weight: torch.Tensor = None,
+        bias: torch.Tensor = None,
+    ) -> None:
+        """
+        Reset the weight and bias of the layer object
+
+        Args:
+            layer: The layer object
+            layer_attr: The attribute name of the sublayer; an empty string means the layer itself
+            weight: The weight of the layer
+            bias: The bias of the layer
+        """
+        assert weight is not None or bias is not None
+        if weight is not None:
+            setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight))
+            self.set_layer_size(layer, layer_attr, weight.shape)
+        if bias is not None:
+            setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias))
+
+    def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
+        """
+        Set the layer size attributes to match the sliced tensor shape
+
+        Args:
+            layer: The layer object
+            layer_attr: The attribute name of the layer
+            size: The torch.Size of the sliced weight
+        """
+        # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
+        attrs = ["out_features", "in_features"]
+        for i, attr in enumerate(attrs):
+            if hasattr_(layer, f"{layer_attr}.{attr}"):
+                setattr_(layer, f"{layer_attr}.{attr}", size[i])
diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py
new file mode 100644
index 000000000..54d7b5ba0
--- /dev/null
+++ b/colossalai/shardformer/shard/shardmodel.py
@@ -0,0 +1,58 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+
+from ..policies.basepolicy import Policy
+from .sharder import ModelSharder
+from .shardconfig import ShardConfig
+
+
+class ShardModel(object):
+    """
+    The class for sharding the huggingface model; self.model is the sharded model.
+    Just create a new ShardModel object to shard a huggingface model
+
+    Args:
+        model: the original huggingface model
+        shard_config: the config for distribution information
+        custom_policy: the custom policy for sharding
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        shard_config: ShardConfig = None,    # TODO
+        custom_policy: Policy = None,
+    ) -> None:
+        self.model = model
+        self.shard_config = shard_config
+        self.policy = custom_policy
+        # self.layout=, # TODO
+
+        sharder = ModelSharder(
+            model=self.model,
+            policy=self.policy,
+            shard_config=self.shard_config,
+        )
+        sharder.shard()
+
+    def set_environ(self) -> None:
+        os.environ["TOKENIZERS_PARALLELISM"] = "true"
+        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
+        os.environ["MASTER_ADDR"] = str(self.shard_config.master_addr)
+        os.environ["MASTER_PORT"] = str(self.shard_config.master_port)
+        os.environ["WORLD_SIZE"] = str(self.shard_config.num_gpus)
+        os.environ["RANK"] = str(self.shard_config.rank)
+        os.environ["LOCAL_RANK"] = str(self.shard_config.rank)
+        if not dist.is_initialized():
+            dist.init_process_group(backend=self.shard_config.backend)
+
+        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
+
+    def back_to_org(self) -> None:
+        pass
\ No newline at end of file
diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py
new file mode 100644
index 000000000..1849cdc99
--- /dev/null
+++ b/colossalai/shardformer/shard/slicer.py
@@ -0,0 +1,167 @@
+import torch
+
+from ..policies.basepolicy import Layer, Col_Layer, Row_Layer
+from .shardconfig import ShardConfig
+
+# the `dim` convention used below: 1 -> column-parallel, 0 -> row-parallel
+dim_mapping = {Col_Layer: 1, Row_Layer: 0}
+
+
+class Slicer():
+
+    def __init__(
+        self,
+        shardconfig: ShardConfig    # TODO
+    ) -> None:
+        self.shardconfig = shardconfig
+
+    def slice_weight_bias(
+        self,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        policy_layer_cls: Layer,
+    ):
+        """
+        Slice the weight and bias according to the policy layer class:
+        Layer -> do nothing
+        Col_Layer -> slice the weight along dim 0 (output dim) and slice the bias
+        Row_Layer -> slice the weight along dim 1 (input dim) and keep the full bias
+
+        Args:
+            weight: The weight of the layer
+            bias: The bias of the layer
+            policy_layer_cls: The class representing how to slice the tensor
+        """
+        if policy_layer_cls == Layer:
+            return weight, bias
+        elif policy_layer_cls == Col_Layer:
+            weight = self.slice_tensor(weight, 1, False)
+            bias = self.slice_tensor(bias, 0, True)
+        elif policy_layer_cls == Row_Layer:
+            weight = self.slice_tensor(weight, 0, False)
+        else:
+            raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
+        return weight, bias
+
+    def slice_weight(
+        self,
+        weight: torch.Tensor,
+        policy_layer_cls: Layer,
+    ) -> torch.Tensor:
+        """
+        Slice the weight according to the shardconfig
+
+        Args:
+            weight: The weight of the layer
+            policy_layer_cls: The class representing how to slice the tensor
+        """
+        if weight is not None:
+            dim = dim_mapping[policy_layer_cls]
+            weight = self.slice_tensor(weight, dim, False)
+        return weight
+
+    def slice_bias(
+        self,
+        bias: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Slice the bias according to the shardconfig
+
+        Args:
+            bias: The bias of the layer
+        """
+        assert bias is not None, "The bias is None"
+        bias = self.slice_tensor(bias, 1, True)    # dim is ignored for the 1D bias
+        return bias
+
+    def slice_tensor(
+        self,
+        tensor_in: torch.Tensor,
+        dim: int,
+        is_bias: bool,
+    ) -> torch.Tensor:
+        """
+        Slice a tensor according to the config: 2D weights are dispatched by dim,
+        1D biases are always sliced along their only dimension
+        """
+        if tensor_in is None:
+            return None
+        if not is_bias:
+            return self.slice_2d(tensor_in, dim)
+        else:
+            return self.slice_1d(tensor_in)
+
+    def slice_2d(
+        self,
+        tensor: torch.Tensor,
+        dim: int,
+    ) -> torch.Tensor:
+        """
+        Slice a 2D tensor
+
+        Args:
+            tensor: The tensor to slice
+            dim: 1 for column-parallel slicing, 0 for row-parallel slicing
+        """
+        assert dim in [0, 1], f"Unsupported slicing mode {dim} for a 2D tensor"
+        if dim == 0:
+            return self.slice_row(tensor)
+        elif dim == 1:
+            return self.slice_col(tensor)
+
+    def slice_1d(
+        self,
+        tensor: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Slice a 1D tensor, keeping this rank's chunk
+
+        Args:
+            tensor: The tensor to slice
+        """
+        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
+        down_idx = self.shardconfig.rank * delta
+        up_idx = down_idx + delta
+        return tensor[down_idx:up_idx]
+
+    def slice_col(
+        self,
+        tensor: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Slice the tensor along dim 0, used for column-parallel weights
+
+        Args:
+            tensor: The tensor to slice
+        """
+        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
+        down_idx = self.shardconfig.rank * delta
+        up_idx = down_idx + delta
+        return tensor[down_idx:up_idx, :]
+
+    def slice_row(
+        self,
+        tensor: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Slice the tensor along dim 1, used for row-parallel weights
+
+        Args:
+            tensor: The tensor to slice
+        """
+        delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
+        down_idx = self.shardconfig.rank * delta
+        up_idx = down_idx + delta
+        return tensor[:, down_idx:up_idx]
\ No newline at end of file
diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py
new file mode 100644
index 000000000..295529429
--- /dev/null
+++ b/colossalai/shardformer/test/config.py
@@ -0,0 +1,5 @@
+parallel = dict(
+    data=1,
+    pipeline=1,
+    tensor=dict(size=2, mode='1d')
+)
\ No newline at end of file
diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py
new file mode 100644
index 000000000..c2a9053ca
--- /dev/null
+++ b/colossalai/shardformer/test/test.py
@@ -0,0 +1,37 @@
+import os
+
+import torch.nn as nn
+from transformers import AutoTokenizer, BertForMaskedLM
+
+import colossalai
+from colossalai.shardformer.shard.shardconfig import ShardConfig
+from colossalai.shardformer.shard.shardmodel import ShardModel
+
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+
+def get_args():
+    parser = colossalai.get_default_parser()
+    return parser.parse_args()
+
+
+def inference(model: nn.Module):
+    token = "Hello, my dog is cute"
+    inputs = tokenizer(token, return_tensors="pt")
+    inputs = inputs.to("cuda")
+    model.to("cuda")
+    outputs = model(**inputs)
+    print(outputs)
+
+
+if __name__ == "__main__":
+    args = get_args()
+    colossalai.launch_from_torch(config=args.config)
+    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+    shard_config = ShardConfig(
+        rank=int(os.environ['RANK']),    # global rank set by the launcher
+        world_size=int(os.environ['WORLD_SIZE']),
+    )
+    shardmodel = ShardModel(model, shard_config)
+    inference(shardmodel.model)
diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py
new file mode 100644
index 000000000..5eba87f6f
--- /dev/null
+++ b/colossalai/shardformer/utils/utils.py
@@ -0,0 +1,56 @@
+def hasattr_(obj, attr: str):
+    """
+    Check whether the object has the multi-sublevel attr
+
+    Args:
+        obj: The object to check
+        attr: The multi-level attr to check
+    """
+    attrs = attr.split('.')
+    for a in attrs:
+        try:
+            obj = getattr(obj, a)
+        except AttributeError:
+            return False
+    return True
+
+
+def setattr_(obj, attr: str, value, ignore: bool = False):
+    """
+    Set the object's multi-sublevel attr to value; if ignore is set, do nothing
+    when the attr does not exist
+
+    Args:
+        obj: The object to set
+        attr: The multi-level attr to set
+        value: The value to set
+        ignore: Whether to ignore when the attr doesn't exist
+    """
+    attrs = attr.split('.')
+    for a in attrs[:-1]:
+        try:
+            obj = getattr(obj, a)
+        except AttributeError:
+            if ignore:
+                return
+            raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
+    setattr(obj, attrs[-1], value)
+
+
+def getattr_(obj, attr: str, ignore: bool = False):
+    """
+    Get the object's multi-sublevel attr
+
+    Args:
+        obj: The object to get from
+        attr: The multi-level attr to get
+        ignore: Whether to ignore when the attr doesn't exist
+    """
+    attrs = attr.split('.')
+    for a in attrs:
+        try:
+            obj = getattr(obj, a)
+        except AttributeError:
+            if ignore:
+                return None
+            raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
+    return obj
\ No newline at end of file
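--
Usage sketch, assuming a single node with 2 GPUs and the standard torchrun
launcher (the --config flag comes from colossalai.get_default_parser):

    torchrun --standalone --nproc_per_node 2 \
        colossalai/shardformer/test/test.py \
        --config colossalai/shardformer/test/config.py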