diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py
index e096c2b13..54cc63ba1 100644
--- a/colossalai/shardformer/policies/autopolicy.py
+++ b/colossalai/shardformer/policies/autopolicy.py
@@ -10,16 +10,26 @@ def build_policies():
     """
     auto_policy_dict = {}
 
-    from transformers.models.bert.modeling_bert import BertForMaskedLM
+    from transformers import BertForMaskedLM
 
     from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
 
-    from transformers.models.bert.modeling_bert import BertForSequenceClassification
+    from transformers import BertForSequenceClassification
 
     from .bert import BertForSequenceClassificationPolicy
     auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
 
+    from transformers import GPT2Model
+
+    from .gpt2 import GPT2Policy
+    auto_policy_dict[GPT2Model] = GPT2Policy
+
+    from transformers import GPT2LMHeadModel
+
+    from .gpt2 import GPT2LMHeadModelPolicy
+    auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
+
     return auto_policy_dict
diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 2eb7eb29e..644d115a2 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -1,11 +1,9 @@
 # part of code modified from https://github.com/tunib-ai/parallelformers
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Tuple, Type
 
-import torch
 import torch.nn as nn
-from transformers import AutoConfig
 
 
 @dataclass
@@ -31,11 +29,18 @@ class Layer:
         bias (str): The bias suffix of the layer
         replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
         ignore (bool): Whether to ignore this layer if it is not in the model
+        reversed (bool): Whether the layer weight is transposed; the weight of `torch.nn.Linear` is [out, in],
+            while the weight of the GPT2 `Conv1D` layer is [in, out], i.e. reversed.
+        n_cast (int): The number of parts fused into this weight, e.g. 3 for the fused q, k, v weight of an attention layer.
+            In plain tensor parallelism the weight is simply chunked by the number of devices, but a fused weight must be
+            chunked by the number of devices * n_cast so that each device receives its own part of Q, K and V.
     """
     weight: str = None
     bias: str = None
     replace_layer: Any = None
     ignore: bool = False
+    reversed: bool = False
+    n_cast: int = None
 
 
 @dataclass
@@ -131,7 +136,7 @@ class Policy():
         (OrignModel, CustomModel)
         in `CustomModel`, we can overwrite the forward and backward process
         """
-        return ()
+        return None
 
     @staticmethod
     def binding_policy() -> Dict:
@@ -146,7 +151,7 @@ class Policy():
             "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
         }
         """
-        return NotImplementedError
+        return None
 
     @staticmethod
     def attn_in() -> List:
@@ -209,4 +214,4 @@ class Policy():
         Return:
             List[Layer]: List of layer object
         """
-        return NotImplementedError
+        return None
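
Note on the new `n_cast` field: the docstring above can be made concrete with a small standalone sketch (toy sizes, not part of the patch). Chunking a fused q, k, v tensor into `world_size` pieces would give a rank only the leading part of the tensor, so the slicer instead chunks it into `world_size * n_cast` pieces and each rank takes every `world_size`-th chunk:

    import torch

    world_size, n_cast, rank = 2, 3, 0    # assumed toy values
    hidden = 8
    qkv = torch.arange(3 * hidden)        # stand-in for a fused q, k, v bias of length 3 * hidden

    # naive split: rank 0 gets all of Q and half of K, but none of V
    naive = qkv.chunk(world_size, dim=0)[rank]

    # n_cast-aware split: world_size * n_cast chunks, every world_size-th chunk belongs to this rank
    chunks = qkv.chunk(world_size * n_cast, dim=0)
    sharded = torch.cat(chunks[rank::world_size], dim=0)

    print(naive.shape, sharded.shape)     # both torch.Size([12]); only `sharded` holds parts of Q, K and V

This mirrors the strided chunking added to `Slicer` further down in this patch.
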
""" weight: str = None bias: str = None replace_layer: Any = None ignore: bool = False + reversed: bool = False + n_cast: int = None @dataclass @@ -131,7 +136,7 @@ class Policy(): (OrignModel, CustomModel) in `CustomModel`, we can overwrite the forward and backward process """ - return () + return None @staticmethod def binding_policy() -> Dict: @@ -146,7 +151,7 @@ class Policy(): "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } """ - return NotImplementedError + return None @staticmethod def attn_in() -> List: @@ -209,4 +214,4 @@ class Policy(): Return: List[Layer]: List of layer object """ - return NotImplementedError + return None diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ab77b29f7..89b32f065 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any, Callable, Dict, List, Tuple, Type import torch.nn as nn diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py new file mode 100644 index 000000000..44dc9c72f --- /dev/null +++ b/colossalai/shardformer/policies/gpt2.py @@ -0,0 +1,118 @@ +from typing import Any, Callable, Dict, List, Tuple, Type + +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model + +import colossalai.shardformer.layer.layers as col_nn + +from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer + + +class GPT2Policy(Policy): + + @staticmethod + def argument_policy(config, world_size): + return { + GPT2Model: + Argument(attr_dict={}, param_funcs=[ + GPT2Policy.embedding, + ]), + GPT2Block: + Argument( + attr_dict={ + # 1. reduce hidden size + "attn.embed_dim": config.hidden_size // world_size, + "attn.split_size": config.hidden_size // world_size, + "crossattention.embed_dim": config.hidden_size // world_size, + "crossattention.split_size": config.hidden_size // world_size, + # 2. 
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 221866188..1ada75e06 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List
 
 import torch
 import torch.nn as nn
+from transformers.pytorch_utils import Conv1D
 
 from ..policies.autopolicy import get_autopolicy
 from ..policies.basepolicy import Policy
@@ -35,10 +36,22 @@ class ModelSharder(object):
         self.model_config = self.model.config
 
     def shard(self) -> None:
+        self.reshape_embedding()
         self.inject_model(self.model)
         self.replace_layer(self.model)
         self.bind_layer(self.model)
 
+    def reshape_embedding(self,) -> None:
+        r"""
+        Resize the token embeddings so that the vocabulary size is divisible by world_size
+        """
+        vocab_size = self.model_config.vocab_size
+        world_size = self.shard_config.world_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        self.model_config = self.model.config
+
     def inject_model(
         self,
         model: nn.Module,
@@ -53,6 +66,8 @@ class ModelSharder(object):
 
         """
         inject_policy = self.policy.inject_policy()
+        if inject_policy is None:
+            return
 
         org_model_cls = inject_policy[0]
         shard_model_cls = inject_policy[1]
@@ -82,9 +97,9 @@ class ModelSharder(object):
             origin_layer_cls = argument_policy[0]
             attr_dict = argument_policy[1].attr_dict
             param_funcs = argument_policy[1].param_funcs
-            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
 
-    def reverse_replace_layer(
+    def traverse_replace_layer(
         self,
         layer: nn.Module,
         origin_cls: nn.Module,
@@ -100,17 +115,12 @@ class ModelSharder(object):
             attr_dict (Dict): The attribute dict to modify
             policy_cls (:class:`Policy`): The policy class
         """
+        if layer.__class__ == origin_cls:
+            for k, v in attr_dict.items():
+                setattr_(layer, k, v, ignore=True)
+            self.shard_one_layer(layer, param_funcs)
         for name, child in layer.named_children():
-            if child.__class__ == origin_cls:
-                # replac_layer = child
-                for k, v in attr_dict.items():
-                    setattr_(child, k, v, ignore=True)
-                # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
-                # setattr_(layer, name, self.shard_one_layer(child, policy_cls))
-                self.shard_one_layer(child, param_funcs)
-                continue
-
-            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
         return layer
 
     def shard_one_layer(
@@ -126,7 +136,6 @@ class ModelSharder(object):
             param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
 
         """
-        # print(org_layer)
         for func in param_funcs:
             policy_layers = func()
             for policy_layer in policy_layers:
@@ -136,9 +145,10 @@
                 bias_attr = policy_layer.bias
                 replace_layer_cls = policy_layer.replace_layer
                 ignore = policy_layer.ignore
+                n_cast = policy_layer.n_cast
+                reversed = policy_layer.reversed
                 if policy_layer.__class__.__name__ == "Col_Layer":
                     gather_output = policy_layer.gather_output
-                    # print(gather_output)
 
                 if weight_attr is not None:
                     if hasattr_(org_layer, weight_attr):
@@ -161,13 +171,11 @@ class ModelSharder(object):
                 layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
 
                 # slice weight and bias
-                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
-                # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
+                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
 
                 # create new object to replace the origin layer
                 if replace_layer_cls is not None:
-                    # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
-                    if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
+                    if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
                         if replace_layer_cls.__name__ == "Linear1D_Row":
                             replace_layer = replace_layer_cls(weight.shape[1],
                                                               weight.shape[0],
@@ -235,6 +243,8 @@ class ModelSharder(object):
             model (:class:`torch.nn.Module`): The shard model
         """
         binding_map = self.policy.binding_policy()
+        if binding_map is None:
+            return
         for k, v in binding_map.items():
             param = getattr_(model, k)
             param = nn.Parameter(param)
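
The rounding in `reshape_embedding` pads the vocabulary up to the next multiple of `world_size` so that the vocabulary-parallel embedding can be split evenly. With illustrative numbers (GPT-2's 50257-token vocabulary on 4 ranks):

    vocab_size, world_size = 50257, 4

    if vocab_size % world_size != 0:
        vocab_size = vocab_size + world_size - vocab_size % world_size

    print(vocab_size)                # 50260, i.e. 3 padding rows are added
    print(vocab_size // world_size)  # 12565 embedding rows per rank
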
diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py
index 26053b9f7..6d35bd193 100644
--- a/colossalai/shardformer/shard/slicer.py
+++ b/colossalai/shardformer/shard/slicer.py
@@ -19,6 +19,8 @@ class Slicer():
         weight: torch.Tensor,
         bias: torch.Tensor,
         policy_layer_cls: Layer,
+        n_cast: int = None,
+        reversed: bool = False,
     ):
         r"""
         Slice the weight and bias according to policy layer cls
@@ -33,13 +35,18 @@ class Slicer():
         """
         if policy_layer_cls == Layer:
             return weight, bias
-        elif policy_layer_cls == Col_Layer:
-            weight = self.slice_tensor(weight, 1, False)
+
+        dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
+        # print(weight.shape, dim)
+        if policy_layer_cls == Col_Layer:
+            weight = self.slice_tensor(weight, dim, False, n_cast)
             bias = self.slice_tensor(bias, 0, True)
         elif policy_layer_cls == Row_Layer:
-            weight = self.slice_tensor(weight, 0, False)
+            weight = self.slice_tensor(weight, dim, False, n_cast)
         else:
             raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
+        if reversed:
+            weight = weight.transpose(0, 1).contiguous()
         return weight, bias
 
     def slice_tensor(
@@ -47,6 +54,7 @@ class Slicer():
         tensor_in: torch.Tensor,
         dim: int,
         is_bias: bool,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice tensor according to the config
@@ -59,14 +67,15 @@ class Slicer():
         if tensor_in is None:
             return None
         if not is_bias:
-            return self.slice_2d(tensor_in, dim)
+            return self.slice_2d(tensor_in, dim, n_cast)
         else:
-            return self.slice_1d(tensor_in)
+            return self.slice_1d(tensor_in, n_cast)
 
     def slice_2d(
         self,
         tensor: torch.Tensor,
         dim: int,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the 2D tensor
@@ -77,13 +86,14 @@ class Slicer():
         """
         assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
         if dim == 0:
-            return self.slice_row(tensor)
+            return self.slice_row(tensor, n_cast)
         elif dim == 1:
-            return self.slice_col(tensor)
+            return self.slice_col(tensor, n_cast)
 
     def slice_1d(
         self,
         tensor: torch.Tensor,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the 1D tensor
@@ -94,11 +104,19 @@ class Slicer():
         Returns:
             :class:`torch.Tensor`: The sliced tensor
         """
-        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=0).contiguous()
 
     def slice_col(
         self,
         tensor: torch.Tensor,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the tensor in column
@@ -110,11 +128,19 @@ class Slicer():
            :class:`torch.Tensor`: The sliced tensor
 
         """
-        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=0).contiguous()
 
     def slice_row(
         self,
         tensor: torch.Tensor,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the tensor in column
@@ -125,4 +151,11 @@ class Slicer():
         Returns:
             :class:`torch.Tensor`: The sliced tensor
         """
-        return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=1).contiguous()
diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py
index b896fd4a4..e2d5a94c7 100644
--- a/colossalai/shardformer/test/test.py
+++ b/colossalai/shardformer/test/test.py
@@ -6,24 +6,28 @@ import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
 
 import colossalai
 from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 
 def get_args():
     parser = colossalai.get_default_parser()
     parser.add_argument("--mode", type=str, default='inference')
     parser.add_argument("--save_model", action='store_true')
+    parser.add_argument("--model", type=str, default='bert-base-uncased')
     return parser.parse_args()
 
 
-def load_data():
+def load_data(args):
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        # tokenizer.pad_token_id = 0
     datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
     # datasets=load_dataset("yelp_review_full")
     tokenized_datasets = datasets.map(
@@ -42,18 +46,23 @@ def load_data():
 
 
 def inference(model: nn.Module, args):
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    print(model)
+    # print(model.wte.weight.shape)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        tokenizer.pad_token_id = 0
     token = "Hello, my dog is cute"
     inputs = tokenizer(token, return_tensors="pt")
     inputs.to("cuda")
     model.eval()
     model.to("cuda")
     outputs = model(**inputs)
-    print(outputs)
+    print(outputs[0])
 
 
 def train(model: nn.Module, args, num_epoch: int = 3):
-    train_dataloader, eval_dataloader = load_data()
+    train_dataloader, eval_dataloader = load_data(args)
     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
     num_training = num_epoch * len(train_dataloader)
     progress_bar = tqdm(range(num_training))
@@ -94,8 +103,13 @@ def train(model: nn.Module, args, num_epoch: int = 3):
 
 if __name__ == "__main__":
     args = get_args()
-    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     colossalai.launch_from_torch(config=args.config)
+    if args.model == 'bert-base-uncased':
+        model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+    elif args.model == 'gpt2':
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+    else:
+        raise AttributeError("model not supported")
     shard_config = ShardConfig(
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),
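
The diff ends inside the `ShardConfig` construction. For orientation, a hypothetical single-process sketch of how the pieces fit together, assuming `shard_model(model, shard_config)` is the entry point implied by the imports at the top of `test.py` (the exact call is not shown in this patch):

    import os

    from transformers import GPT2LMHeadModel

    from colossalai.shardformer.shard import ShardConfig, shard_model
    from colossalai.utils import get_current_device

    # assumed to run under a distributed launcher so that WORLD_SIZE is set,
    # mirroring the ShardConfig construction visible in the hunk above
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    shard_config = ShardConfig(
        rank=int(str(get_current_device()).split(':')[-1]),
        world_size=int(os.environ['WORLD_SIZE']),
    )
    sharded_model = shard_model(model, shard_config)    # assumed signature; the call itself is outside this diff
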