diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 50c927380..77c2af8d1 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import ShardConfig, shard_model +from .shard import ShardConfig, ShardFormer diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index e864719ac..6239397b7 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -1,5 +1,7 @@ import torch.nn as nn +from .basepolicy import Policy + def build_policies(): r""" @@ -41,47 +43,25 @@ def build_policies(): auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy from transformers.models.llama.modeling_llama import LlamaModel - from .llama import LlamaPolicy - auto_policy_dict[LlamaModel] = LlamaPolicy - - from transformers import LlamaForSequenceClassification - - from .llama import LlamaForSequenceClassificationPolicy - auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy - - from transformers import LlamaForCausalLM - - from .llama import LlamaForCausalLMPolicy - auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy - - from transformers import BertForMultipleChoice - - from .bert import BertForMultipleChoicePolicy - auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy - - from transformers import GPT2Model - - from .gpt2 import GPT2Policy - auto_policy_dict[GPT2Model] = GPT2Policy - - from transformers import GPT2LMHeadModel - - from .gpt2 import GPT2LMHeadModelPolicy - auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy - - from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy - from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model - t5 = { - T5ForConditionalGeneration: T5ForConditionalGenerationPolicy, - T5EncoderModel: T5EncoderModelPolicy, - T5Model: T5ModelPolicy, - } - auto_policy_dict.update(t5) + # from .llama import LlamaPolicy + # auto_policy_dict[LlamaModel] = LlamaPolicy + # from transformers import LlamaForSequenceClassification + # from .llama import LlamaForSequenceClassificationPolicy + # auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy + # from transformers import LlamaForCausalLM + # from .llama import LlamaForCausalLMPolicy + # auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy + # from transformers import GPT2Model + # from .gpt2 import GPT2Policy + # auto_policy_dict[GPT2Model] = GPT2Policy + # from transformers import GPT2LMHeadModel + # from .gpt2 import GPT2LMHeadModelPolicy + # auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy return auto_policy_dict -def get_autopolicy(model: nn.Module): +def get_autopolicy(model: nn.Module) -> Policy: r""" Return the auto policy for the model @@ -97,7 +77,7 @@ def get_autopolicy(model: nn.Module): raise NotImplementedError( f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}" ) - return policy + return policy() # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index ba3a97f1b..80ea7a252 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,102 +1,65 @@ # part 
of code modified from https://github.com/tunib-ai/parallelformers +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Type, Union import torch.nn as nn +from ..shard.shard_config import ShardConfig -@dataclass -class Argument: - r""" - The argument class for the policy - Args: - attr_dict (Dict[str, Any]): The dict for the param setting - param_funcs (:class:`List[Callable]`): The list for the param functions - """ - attr_dict: Dict[str, Any] - param_funcs: List[Callable] +class ParallelModule(): + + def __init__(self): + pass @dataclass -class Layer: +class SubModuleReplacementDescription: r""" - The layer object for the policy + Describe how a submodule will be replaced - Args: - suffix: (str): the suffix of the layer. - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - ignore (bool): Whether to ignore this layer if it is not in the model - reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], - but in GPT2 `Conv1D` layer is [in, out] which is reversed. - n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, - but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and - each device should have a part of Q, K and V weight. + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace the submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. """ - suffix: str = None - replace_layer: Any = None - ignore: bool = False - reversed: bool = False - n_cast: int = None + suffix: str + target_module: ParallelModule + kwargs: Dict[str, Any] = None @dataclass -class Col_Layer(Layer): +class ModulePolicyDescription: r""" - Class for col shard layer in tensor parrallel + Describe how the attributes and parameters will be transformed in a policy - Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer - gather_output (bool): Whether to gather the output of the layer + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function + must receive two arguments: module, process_group.
One example is + + ```python + def example_replace_weight(module: torch.nn.Module, process_group): + weight = module.weight + new_weight = shard_rowwise(weight, process_group) + module.weight = torch.nn.Parameter(new_weight) + ``` + + sub_module_replacement: each element in the list is a SubModuleReplacementDescription object which specifies + the submodule to be replaced and the target module used for the replacement """ - weight: str = None - bias: str = None - gather_output: bool = False + attribute_replacement: Dict[str, Any] + param_replacement: List[Callable] + sub_module_replacement: List[SubModuleReplacementDescription] -@dataclass -class Row_Layer(Layer): - r""" - Class for col shard layer in tensor parrallel - - Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer - """ - weight: str = None - bias: str = None - - -@dataclass -class Dropout_Layer(Layer): - r""" - Class for dropout layer in tensor parrallel - - Args: - p (str): The dropout rate suffix of the layer - """ - p: str = None - - -@dataclass -class Embedding_Layer(Layer): - r""" - Class for col shard layer in tensor parrallel - - Args: - weight (str): The weight suffix of the layer - """ - weight: str = None - gather_output: bool = True - - -class Policy(): +class Policy(ABC): r""" The base class for all the policies + For each different model, it should have a different policy class, like BertPolicy for Bert Model or OPTPolicy for OPT model. + AutoPolicy: Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None to use the auto policy. In shardformer autopolicy, we define a base policy for one type model, @@ -111,137 +74,75 @@ class Policy(): """ - @staticmethod - def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]: + def __init__(self) -> None: + self.model = None + + def set_model(self, model: nn.Module) -> None: + r""" + Set model as an attribute of the Policy object so that we can access the model's attributes. + + Args: + model (:class:`nn.Module`): The model to be sharded + """ + self.model = model + + @abstractmethod + def preprocess(self, shard_config: ShardConfig = None) -> nn.Module: + r""" + Perform some preprocessing of the model, like reshaping the embedding layer + """ + + @abstractmethod + def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: r""" Return the dict for the modify policy, the key is the original layer class and the value is the argument for the modify layer - Args: - model_config (:class:`tansformer.Config`): The config of transformer model - world_size (int)): The world size of sharding model - Return: Dict for the modify policy, :: { - origin layer class1 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, + origin layer class1 (nn.Module): ModulePolicyDescription( + attribute_replacement = { + "attribute1": value1, + "attribute2": value2, ... }, - param_funcs = [ - staticmethod1, - staticmethod2, + param_replacement = [ + function1, + function2, + ... + ], + sub_module_replacement = [ + `SubModuleReplacementDescription` description1, + `SubModuleReplacementDescription` description2, ... ] ), - origin layer class2 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, - ... - ] + origin layer class2 (nn.Module): ModulePolicyDescription( + ... ), ... 
} - """ - raise NotImplementedError - @staticmethod - def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: + @abstractmethod + def new_model_class(self) -> Union[Type[nn.Module], None]: r""" - Return the dict for the inject model + Return the new model class for the new model, None means no need to modify the model class Return: - The injected model, key is the original model and value is the new shardmodel - :: - (OrignModel, CustomModel) - in `CustomModel`, we can overwrite the forward and backward process - """ - return None + New model class - @staticmethod - def binding_policy() -> Union[Dict[str, str], None]: + E.g. + ``` + return BertModel_ + ``` + """ + + @abstractmethod + def postprocess(self) -> nn.Module: r""" - Return the dict for the binding model, None means no need to bind - - Return: - This method should return the binding relationship for some layers share the weight or bias, - the key and value is the suffix of the weight or bias of the model - :: - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } + Perform some postprocessing of the model, like binding the weight of embedding layer with + the classifier layer """ - return None - - @staticmethod - def attn_in() -> Union[List, None]: - r""" - Attention qkv layer - In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be - ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters - in ``Layer`` object can refer to the ``Layer`` class. - - Returns: - List[Layer]: List of layer object, each layer is the new - """ - return None - - @staticmethod - def attn_out() -> Union[List, None]: - r""" - Attention output projection layer - - Returns: - List[Layer]: List of layer object - """ - return None - - @staticmethod - def mlp_in() -> Union[List, None]: - r""" - h -> 4h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return None - - @staticmethod - def mlp_out() -> Union[List, None]: - r""" - 4h -> h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return None - - @staticmethod - def embedding() -> Union[List, None]: - r""" - Partially slice the embedding layer - - Return: - List[Layer]: List of layer object - """ - return None - - @staticmethod - def unembedding() -> Union[List, None]: - r""" - Partially slice the embedding layer, None means there is no unembedding layer - - Return: - List[Layer]: List of layer object - """ - return None diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ba2266353..f3431c386 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,220 +1,77 @@ -from typing import Any, Callable, Dict, List, Tuple, Type - import torch.nn as nn from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead import colossalai.shardformer.layer.layers as col_nn -from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer +from ..shard.shard_config import ShardConfig +from ..utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class ParallelModule(): + + def __init__(self): + pass class BertPolicy(Policy): - @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: + def preprocess(self, shard_config: ShardConfig = None): + # reshape the embedding layer + r""" + Reshape the Embedding 
layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.vocab_size + world_size = shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self, shard_config: ShardConfig = None): return { BertLayer: - Argument( - attr_dict={ + ModulePolicyDescription( + attribute_replacement={ # 1. shard hidden size - "attention.self.all_head_size": config.hidden_size // world_size, - "crossattention.self.all_head_size": config.hidden_size // world_size, + "attention.self.all_head_size": + self.model.config.hidden_size // shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": + self.model.config.hidden_size // shard_config.tensor_parallel_size, # 2. shard number of heads - "attention.self.num_attention_heads": config.num_attention_heads // world_size, - "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, + "attention.self.num_attention_heads": + self.model.config.num_attention_heads // shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": + self.model.config.num_attention_heads // shard_config.tensor_parallel_size, }, - param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]), - BertEmbeddings: - Argument( - attr_dict={ - # 1. shard vocab size - "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, - }, - param_funcs=[ - BertPolicy.embedding, - ]), + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=ParallelModule, + ), + ]) } - @staticmethod - def attn_in(): - return [ - Col_Layer( - suffix="attention.self.query", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="attention.self.key", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="attention.self.value", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Dropout_Layer( - suffix="attention.self.dropout", - p="p", - replace_layer=col_nn.Dropout1D, - ), - Col_Layer( - suffix="crossattention.self.query", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - Col_Layer( - suffix="crossattention.self.key", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - Col_Layer( - suffix="crossattention.self.value", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - ] - - @staticmethod - def attn_out(): - return [ - Row_Layer( - suffix="attention.output.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ), - Dropout_Layer( - suffix="attention.output.dropout", - p="p", - replace_layer=col_nn.Dropout1D, - ), - Row_Layer( - suffix="crossattention.output.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ignore=True, - ), - ] - - @staticmethod - def mlp_in(): - return [ - Col_Layer( - suffix="intermediate.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - ] - - @staticmethod - def mlp_out(): - return [ - Row_Layer( - suffix="output.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ), - Dropout_Layer( - suffix="output.dropout", - p="p", - 
replace_layer=col_nn.Dropout1D, - ) - ] - - @staticmethod - def embedding(): - return [Col_Layer( - suffix="word_embeddings", - weight="weight", - replace_layer=col_nn.VocabParallelEmbedding1D, - )] - - @staticmethod - def unembedding(): - return [ - Col_Layer( - suffix="decoder", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ) - ] - - -# BertModel -class BertModelPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) - - -# BertForPretraining -class BertForPretrainingPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - base_argument = BertPolicy.argument_policy(config, world_size) - argument = { - BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ - BertPolicy.unembedding, - ]), - } - argument.update(base_argument) - return argument - - @staticmethod - def inject_policy(): + def new_model_class(self): + # do nothing return None - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - - -# BertForMaskedLM -from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model class BertForMaskedLMPolicy(BertPolicy): - @staticmethod - def argument_policy(config, world_size): - base_argument = BertPolicy.argument_policy(config, world_size) - argument = { - BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ - BertPolicy.unembedding, - ]), - } - argument.update(base_argument) - return argument - - @staticmethod - def inject_policy(): - # return (BertForMaskedLM, BertForMaskedLM_) - return None - - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } + def __init__(self) -> None: + super().__init__() # BertLMHeadModel @@ -231,36 +88,5 @@ class BertLMHeadModelPolicy(BertPolicy): argument.update(base_argument) return argument - @staticmethod - def inject_policy(): - return None - - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - - -# BertForNextSentencePrediction -class BertForNextSentencePredictionPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) - - -# BertForSequenceClassification -class BertForSequenceClassificationPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) - - -# BertForMultipleChoice -class BertForMultipleChoicePolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) + def __init__(self) -> None: + super().__init__() diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index d5f70163a..7abdd45ec 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,5 +1,5 @@ from .shard_config import ShardConfig -from .sharder import ModelSharder, shard_model -from .slicer import Slicer +from .sharder import ModelSharder +from .shardformer import 
ShardFormer -__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer'] +__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer'] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 96c287577..53999529d 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import List, Literal __all__ = ['ShardConfig'] @@ -9,10 +10,18 @@ class ShardConfig: The config for sharding the huggingface model Args: - rank (int): The rank of local process - world_size (int): The world size of the distributed process + data_parallel_size (int): The size of the data parallel group + tensor_parallel_size (int): The size of the tensor parallel group + pipeline_parallel_size (int): The size of the pipeline parallel group + tensor_parallel_mode (Literal): The mode of tensor parallelism, chosen from `['1d', '2d', '2.5d', '3d']` + inference_only (bool): Whether to use inference-only mode. When set to `True`, the model + will not calculate the loss and will just return the output. gather_output (bool): Whether to gather the output of the model of the last layer """ - rank: int = None - world_size: int = None + data_parallel_size: int + tensor_parallel_size: int + + pipeline_parallel_size: int + tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] + inference_only: bool = True gather_output: bool = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 7ef0c37a4..8eee3c6a3 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -4,11 +4,12 @@ import torch import torch.nn as nn from transformers.pytorch_utils import Conv1D +from colossalai.cluster.process_group_manager import ProcessGroupManager + from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer -from ..utils.utils import getattr_, hasattr_, setattr_ +from ..policies.basepolicy import Policy +from ..utils.utils import setattr_ from .shard_config import ShardConfig -from .slicer import Slicer __all__ = ['ModelSharder', 'shard_model'] @@ -28,20 +29,23 @@ class ModelSharder(object): model: nn.Module, policy: Policy, shard_config: ShardConfig = None, # TODO - ) -> None: + pg_manager: ProcessGroupManager = None) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy - self.slicer = Slicer(shard_config) self.shard_config = shard_config - self.model_config = self.model.config + self.pg_manager = pg_manager def shard(self) -> None: - self.reshape_embedding() - self.inject_model(self.model) - self.replace_layer(self.model) - self.bind_layer(self.model) + r""" + Shard the model according to the policy + """ + self.policy.set_model(self.model) + self.preprocess() + self.replace_model_class() + self.replace_module() + self.postprocess() - def reshape_embedding(self,) -> None: + def reshape_embedding(self) -> None: r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ @@ -52,10 +56,13 @@ class ModelSharder(object): self.model.resize_token_embeddings(new_vocab_size) self.model_config = self.model.config - def inject_model( - self, - model: nn.Module, - ) -> None: + def preprocess(self) -> None: + self.model = self.policy.preprocess(self.shard_config) + + def postprocess(self) -> None: + self.model = self.policy.postprocess() + + def replace_model_class(self,) -> None: 
r""" Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model @@ -64,49 +71,43 @@ class ModelSharder(object): :: BertForMaskedLM.forward -> BertForMaskedLM_.forward """ - inject_policy = self.policy.inject_policy() - if inject_policy is None: + new_model_class = self.policy.new_model_class() + if new_model_class is None: return - if inject_policy is None: - return - org_model_cls = inject_policy[0] - shard_model_cls = inject_policy[1] + for key in new_model_class.__dict__.keys(): + if hasattr(self.model.__class__, key): + setattr( + self.model.__class__, + key, + getattr(new_model_class, key), + ) - if model.__class__ == org_model_cls: - for key in shard_model_cls.__dict__.keys(): - if hasattr(model.__class__, key): - setattr( - model.__class__, - key, - getattr(shard_model_cls, key), - ) - else: - raise NotImplementedError(f"{model.__class__} is not implemented so far") - - def replace_layer( - self, - model: nn.Module, - ) -> None: + def replace_module(self,) -> None: r""" - Replace the layer according to the policy, and replace the layer one by one + Replace the module according to the policy, and replace the module one by one Args: - model (:class:`torch.nn.Module`): The layer to shard + model (:class:`torch.nn.Module`): The model to shard """ - argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) - for argument_policy in argument_policies.items(): - origin_layer_cls = argument_policy[0] - attr_dict = argument_policy[1].attr_dict - param_funcs = argument_policy[1].param_funcs - self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) + print(self.policy) + module_descriptions = self.policy.module_policy(self.shard_config) + print(f"*******{module_descriptions}") + for module_description in module_descriptions.items(): + origin_layer_cls = module_description[0] + attr_replacement = module_description[1].attribute_replacement + param_replacement = module_description[1].param_replacement + sub_module_replacement = module_description[1].sub_module_replacement + self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement, + sub_module_replacement) - def traverse_replace_layer( + def _recursive_replace_layer( self, - layer: nn.Module, + module: nn.Module, origin_cls: nn.Module, - attr_dict: Dict[str, Any], - param_funcs: List[Callable], + attr_replacement: Dict[str, Any], + param_replacement: List[Callable], + sub_module_replacement: List[Callable], ) -> None: r""" Reverse the replace layer operation @@ -114,21 +115,52 @@ class ModelSharder(object): Args: layer (:class:`torch.nn.Module`): The object of layer to shard origin_cls (:class:`transformers.model`): The origin layer class - attr_dict (Dict): The attribute dict to modify - policy_cls (:class:`Policy`): The policy class + attr_replacement (Dict): The attribute dict to modify + param_replacement (List[Callable]): The function list to get parameter shard information in polic + sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy """ - if layer.__class__ == origin_cls: - for k, v in attr_dict.items(): - setattr_(layer, k, v, ignore=True) - self.shard_one_layer(layer, param_funcs) - for name, child in layer.named_children(): - self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs) - return layer + if module.__class__ == origin_cls: + self._replace_attr(module, attr_replacement) + self._replace_param(module, param_replacement) 
+ self._replace_sub_module(module, sub_module_replacement) + for name, child in module.named_children(): + self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, + sub_module_replacement) - def shard_one_layer( + def _replace_attr( + self, + module: nn.Module, + attr_replacement: Dict[str, Any], + ) -> None: + r""" + Replace the attribute of the layer + + Args: + layer (:class:`torch.nn.Module`): The object of layer to shard + attr_replacement (Dict): The attribute dict to modify + """ + for k, v in attr_replacement.items(): + setattr_(module, k, v, ignore=True) + + def _replace_param( + self, + module: nn.Module, + param_replacement: List[Callable], + ) -> None: + r""" + Replace the parameter of the layer + + Args: + layer (:class:`torch.nn.Module`): The object of layer to shard + param_replacement (List[Callable]): The function list to get parameter shard information in policy + """ + # TODO: support parameter shard + pass + + def _replace_sub_module( self, org_layer: nn.Module, - param_funcs: List[Callable], + sub_module_replacement: List[Callable], ) -> None: r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict @@ -138,145 +170,14 @@ class ModelSharder(object): param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class """ - for func in param_funcs: - policy_layers = func() - for policy_layer in policy_layers: - suffix = policy_layer.suffix - replace_layer_cls = policy_layer.replace_layer - ignore = policy_layer.ignore - reversed = policy_layer.reversed - n_cast = policy_layer.n_cast + for description in sub_module_replacement: + suffix = description.suffix + target_module = description.target_module + kwargs = description.kwargs - assert replace_layer_cls is not None, 'replace_layer should not be None' + assert target_module is not None, 'target_module should not be None' - # create new object to replace the origin layer - # Linear - suffix_layer = getattr_(org_layer, suffix, ignore=True) - assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}" - if suffix_layer is None and ignore: - continue - if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)): - weight = None - bias = None - weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None - bias_attr = suffix + '.' 
+ policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None - - if weight_attr is not None: - if hasattr_(org_layer, weight_attr): - weight = getattr_(org_layer, weight_attr) - else: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") - - if bias_attr is not None: - if hasattr_(org_layer, bias_attr): - bias = getattr_(org_layer, bias_attr) - else: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") - - # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None - - # slice weight and bias - weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) - - if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], - weight.shape[0], - bias=False if bias is None else True) - elif replace_layer_cls.__name__ == "Linear1D_Col": - gather_output = policy_layer.gather_output and self.shard_config.gather_output - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - bias=False if bias is None else True, - gather_output=gather_output) - elif replace_layer_cls.__name__ == "Embedding1D": - gather_output = policy_layer.gather_output - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - gather_output=gather_output) - elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D": - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], - getattr_(org_layer, f"{suffix}.padding_idx", ignore=True)) - # setattr_(org_layer, suffix, replace_layer, ignore=ignore) - # self.set_param(replace_layer, weight, bias) - else: - raise NotImplementedError( - f"Replacing to {replace_layer_cls.__name__} is not implemented so far") - setattr_(org_layer, suffix, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - # dropout - elif isinstance(policy_layer, Dropout_Layer): - p_attr = suffix + '.' 
+ policy_layer.p - p = getattr_(org_layer, p_attr, ignore=True) - replace_layer = replace_layer_cls(p) - setattr_(org_layer, suffix, replace_layer, ignore=ignore) - else: - raise NotImplementedError( - f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far") - - def set_param(self, - layer: Any, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - layer_attr: str = "") -> None: - r""" - Reset the weight and bias of the layer object - - Args: - layer (:class:`torch.nn.Module`): The layer object - layer_attr (str): The attribute name of the layer - weight (:class:`torch.Tensor`): The weight of the layer - bias (:class:`torch.Tensor`): The bias of the layer - """ - assert weight is not None or bias is not None - if weight is not None: - setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous())) - self.set_layer_size(layer, layer_attr, weight.shape) - if bias is not None: - setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous())) - - def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: - r""" - Set the layer attribute - - Args: - layer (:class:`torch.nn.Module`): The layer object - layer_attr (str): The attribute name of the layer - size (:class:`torch.Size`): The size of the tensor - """ - # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features - attrs = ["out_features", "in_features"] - for i, attr in enumerate(attrs): - if hasattr_(layer, f"{layer_attr}.{attr}"): - setattr_(layer, f"{layer_attr}.{attr}", size[i]) - - def bind_layer(self, model: nn.Module) -> None: - r""" - Bind the layer according to the binding policy - - Args: - model (:class:`torch.nn.Module`): The shard model - """ - binding_map = self.policy.binding_policy() - if binding_map is None: - return - for k, v in binding_map.items(): - param = getattr_(model, k) - param = nn.Parameter(param) - setattr_(model, k, param) - setattr_(model, v, param) - - -def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None): - r""" - The function is used to shard the PyTorch model. 
- - Args: - model (`torch.nn.Model`): the origin huggingface model - shard_config (`ShardConfig`): the config for distribute information - policy (`Policy`): the custom policy for sharding - """ - # TODO: init shard_config automatically - sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy) - sharder.shard() - return model + # TODO: integrate with new layer + # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager) + replace_layer = None + setattr_(org_layer, suffix, replace_layer) diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py new file mode 100644 index 000000000..5313dfecb --- /dev/null +++ b/colossalai/shardformer/shard/shardformer.py @@ -0,0 +1,77 @@ +import torch.nn as nn +from torch.utils.data import Dataset + +from colossalai.cluster import DistCoordinator, ProcessGroupManager + +from ..policies.basepolicy import Policy +from .shard_config import ShardConfig +from .sharder import ModelSharder + + +class ShardFormer: + """ + Parallelize the model based on the given config and policy + + Example: + + ```python + from colossalai.shardformer import ShardFormer, ShardConfig + from transformers import BertForMaskedLM + import colossalai + import torch + + colossalai.launch_from_torch(config={}) + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') + shard_config = ShardConfig( + tensor_parallel_size=2, + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_mode='1d', + inference_only=True, + gather_output=True + ) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + model = shard_former.shard_model(org_model) + ``` + """ + + def __init__(self, shard_config: ShardConfig): + """ + Do two things: + 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp + 2. Serve as a store for the shard config + """ + self.coordinator = DistCoordinator() + self.shard_config = shard_config + self.pg_manager = None + + def init_distributed(self) -> ProcessGroupManager: + """ + Initialize the distributed process group according to the shard config + """ + pg_manager = ProcessGroupManager() + if (self.shard_config.tensor_parallel_mode == '1d'): + pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size)) + self.pg_manager = pg_manager + return pg_manager + + def shard_model(self, model: nn.Module, policy: Policy = None): + r""" + The function is used to shard the PyTorch model. 
+ + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding + """ + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager) + sharder.shard() + return model + + def shard_dataset(self, dataset: Dataset): + """ + Shard dataset for DP + """ + pass diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py deleted file mode 100644 index 860533dca..000000000 --- a/colossalai/shardformer/shard/slicer.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch - -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer -from .shard_config import ShardConfig - -dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1} - - -class Slicer(): - - def __init__( - self, - shardconfig: ShardConfig #TODO - ) -> None: - self.shardconfig = shardconfig - - def slice_weight_bias( - self, - weight: torch.Tensor, - bias: torch.Tensor, - policy_layer_cls: Layer, - n_cast: int = None, - reversed: bool = False, - ): - r""" - Slice the weight and bias according to policy layer cls - ``Layer`` -> do nothing - ``Col_Layer`` -> slice the weight and bias along dim 1 - ``Row_Layer`` -> slice the weight along dim 0 and do not slice bias - - Args: - weight (:class:`torch.nn.Module`): The weight of the layer - bias: (:class:`torch.nn.Module`): The bias of the layer - policy_layer_class (:class:`Policy`): The class represent how to slice the tensor - """ - if policy_layer_cls in [Layer, Dropout_Layer]: - return weight, bias - - dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls]) - # print(weight.shape, dim) - if policy_layer_cls == Col_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - bias = self.slice_tensor(bias, 0, True, n_cast) - elif policy_layer_cls == Row_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - elif policy_layer_cls == Embedding_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - else: - raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") - if reversed: - weight = weight.transpose(0, 1).contiguous() - return weight, bias - - def slice_tensor( - self, - tensor_in: torch.Tensor, - dim: int, - is_bias: bool, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice tensor according to the config - - Args: - tensor_in (:class:`torch.Tensor`): The tensor to slice - dim (int): The dimension to slice - is_bias (bool): Whether the tensor is bias - """ - if tensor_in is None: - return None - if not is_bias: - return self.slice_2d(tensor_in, dim, n_cast) - else: - return self.slice_1d(tensor_in, n_cast) - - def slice_2d( - self, - tensor: torch.Tensor, - dim: int, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the 2D tensor - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - dim (int): The dimension to slice - """ - assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor" - if dim == 0: - return self.slice_row(tensor, n_cast) - elif dim == 1: - return self.slice_col(tensor, n_cast) - - def slice_1d( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the 1D tensor - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, 
dim=0)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=0).contiguous() - - def slice_col( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the tensor in column - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=1).contiguous() - - def slice_row( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the tensor in column - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=0).contiguous() diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py index e69de29bb..b50e7b2f6 100644 --- a/colossalai/shardformer/utils/__init__.py +++ b/colossalai/shardformer/utils/__init__.py @@ -0,0 +1 @@ +from .utils import getattr_, hasattr_, setattr_
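For reference, a minimal sketch of how the pieces introduced in this patch fit together: a custom policy built from `ModulePolicyDescription` and `SubModuleReplacementDescription`, then handed to `ShardFormer.shard_model`. It condenses the `BertPolicy` from this patch and reuses the placeholder `ParallelModule`; since `ModelSharder._replace_sub_module` still sets the replacement to `None` (see the TODO above), this exercises the new interface rather than producing a fully sharded model yet.

```python
import torch.nn as nn
from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import BertLayer

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.basepolicy import (
    ModulePolicyDescription,
    ParallelModule,
    Policy,
    SubModuleReplacementDescription,
)


class MyBertPolicy(Policy):
    """Condensed sketch of a policy using the new description-based interface."""

    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
        # Pad the vocabulary so it divides evenly across the tensor parallel group.
        vocab_size = self.model.config.vocab_size
        world_size = shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            self.model.resize_token_embeddings(vocab_size + world_size - vocab_size % world_size)
        return self.model

    def module_policy(self, shard_config: ShardConfig = None):
        # Shard the per-layer attention attributes and mark the query projection
        # for replacement; ParallelModule is still the placeholder from this patch.
        return {
            BertLayer:
                ModulePolicyDescription(
                    attribute_replacement={
                        "attention.self.num_attention_heads":
                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
                    },
                    param_replacement=[],
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="attention.self.query",
                            target_module=ParallelModule,
                        ),
                    ],
                )
        }

    def new_model_class(self):
        return None    # keep the original forward/backward

    def postprocess(self) -> nn.Module:
        return self.model    # e.g. re-tie shared embedding/decoder weights here


if __name__ == "__main__":
    colossalai.launch_from_torch(config={})
    shard_config = ShardConfig(data_parallel_size=1,
                               tensor_parallel_size=2,
                               pipeline_parallel_size=1,
                               tensor_parallel_mode='1d')
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    model = shard_former.shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased'),
                                     policy=MyBertPolicy())
```

Run under `torchrun` with two processes so that `launch_from_torch` can pick up the distributed environment; passing `policy=None` instead falls back to `get_autopolicy`.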