From c1c672d0f0fcba484f294ad8550df59ee5448fdd Mon Sep 17 00:00:00 2001 From: wukong1992 Date: Thu, 15 Jun 2023 16:50:08 +0800 Subject: [PATCH] [shardformer] shardformer support t5 model (#3994) test t5 --- applications/Chat/coati/trainer/.sft.py.swp | Bin 0 -> 20480 bytes colossalai/shardformer/layer/layers.py | 8 +- colossalai/shardformer/policies/autopolicy.py | 9 + colossalai/shardformer/policies/basepolicy.py | 12 ++ colossalai/shardformer/policies/t5.py | 159 ++++++++++++++++++ colossalai/shardformer/shard/sharder.py | 11 +- colossalai/shardformer/shard/slicer.py | 6 +- colossalai/shardformer/utils/utils.py | 25 ++- requirements/requirements-test.txt | 1 + .../test_model/test_shard_t5.py | 99 +++++++++++ 10 files changed, 320 insertions(+), 10 deletions(-) create mode 100644 applications/Chat/coati/trainer/.sft.py.swp create mode 100644 colossalai/shardformer/policies/t5.py create mode 100644 tests/test_shardformer/test_model/test_shard_t5.py diff --git a/applications/Chat/coati/trainer/.sft.py.swp b/applications/Chat/coati/trainer/.sft.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..302cf2a775338fb4fcd6b9b12c1a8e80f3969a01 GIT binary patch literal 20480 zcmeHOU5q4E6)psY74VP5@FaJ89H$z3t7mp0aohBo-E~(9yMp_JSthfc>h9aq#Z*@{ zRn;uJ!vHZRh#HOZ#2YFJFB0QVSObsxWI{~Lk{BNh;)6b@tf-4W;QF2WQ+2DldsbdF zA-9unZdKiL?>+as=box_s;0Yq@0~~4UT@jK?*oo=(;FxI-#GnI=k?bd=l&oa`tc}; zsBW8IF4yNr58UI0{+tGnwG2>WBO9Z2IGI81(!sh@-T5aOh{vUW~J5 zmiR%$lV&=|yz$v>bg)&H0n5OZ7`V~dvwKBs^@Z9_cGHJWUa5htw+vVYECZGS%YbFT zGGH073|I!UFsSOzQumI2Ga|A7I&?>KiO*ng4|0C@keHBLF9S~i4+E<} z3%D7$h(y4zfBv-t%5*a`ZrO?|Z0V174>laU{aAzEycM%t>?or)eT z7)T}dwb^d(4(3OF7NniTk2X5XZoM_Fg3>&mWaCMedUQH1q-rrkF||?L=bgLq$?QqGObFGlw zsur=BXQ9tuXN* zT09OpAri(tKNK?cdNBm{(In*^F^V*UJ|D0FMx$9b)UIt9LaZCdA#DAyAEq3p115z* z{>`AYG)r^FmE&kK>WKXTk05G(MX;eL>PT*BWlX2cfHuQQt2;ICh7{ju+ zyrA2V(pmqIh0}G%TZ!m0! zuzJ@o<|mwylTIh&ahzlv_n``0Jis|OV;YhZYFtuXv)eOs7m2C0u1LIu6La(7W#S>G1>6z+*f z#61zckXD6kb!|=ILP}U%i`_kAn<|c4OtV$A62@4Zk{!!4@_f0?{%%}z zh}c@KCQ3RMT6H$ylGan^_JdwVH;sWm3A2=CG22`Ulcm-Q-*x*|y~VO2XG#`z@Wt%R zFbFxg4YI%wgL80=vP*JA`5^hSF}J*I<*`=R)DVR{MxhhdpCr_^V44oplT1EBv6jrm zLYMp9t-6In>cS1e&XpfS@1n=N8~JRPLiriptf+2;5EDXj_FAjFn-5K6S(muQS)X-u*-Q38O%Hk?52`wot8x1|xR_{*cYCBZT^NQ_SOZVw! zanc)lQRE$s>8_)Ckd-CA_rv0hN8c0j>twB9nOT-_SZH6mW=}9hM%)X1JmDWdd_pyL zRcYMl|6v-xGf3hQ^CnpkrYsnd3oM>j@#GPA`uxG5$CaF#cEEGtX}{LU_pEAE*sdDj zpz=|wzUza2Ri)H*WL1Xe-4qKm;z{b^gg(~RlBFB>Eb(}QN5MH6*NFMy$*>sco`lP> z`X{6yPY$3ucN_mSSki9eU&b}T8^sG+A39cQi1ijjdSs(oqVSidoG?T!l3S^x_z(|v zxl(UJ0sR=R1f+4{IX_E+?u71NeoE0trh5wgVA3!DOO z27ZE={t@69@Ht>7@G5lu4e%WB7*GRt0yKvuD(%lQU>UFsSOzQumVs>I)s>jiJTo%Io9T8B?Yr36Urjv+!=8{+FB~71M~<8-s$A*#%|0$B8Agnj(`}mV2&{9PSOKJJ^azp3@v(PnE z!oFvzi?B{4y2|MzYJiQ2?KN%6rMDy7Y}o~s2*;LnH)}M%TFW%rt+IY{$nyRcq1em~ zF7pyw>;PriS@Eig9mCdTo0XkxIsvt1$PMT)=8rgVmBZ54{$zwJq?-9#ATos!(2Ju* zOlloD$&4rDj-=CYOnH5*UxO5+=^C@#=FNgBHiJI1I)4ub(zU9NDNsGG2}&U?%K2M~ z+SS4Yg2J5gB)M>$NHb)e5Lw#Q_QrZG26NCeC8`7U!w>MbEwzq59?&SKNOkO z%6pM&i^`;E$toL3w-AzfzQ{w)NN-I?iWtKD+hn9b%7 literal 0 HcmV?d00001 diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index f5123885b..a9f3cf5ad 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -770,6 +770,7 @@ class Embedding1D(ParallelLayer): embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, + gather_output: bool = True, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -782,6 +783,7 @@ class Embedding1D(ParallelLayer): self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs + self.gather_output = gather_output self.weight = Parameter( torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) @@ -832,8 +834,10 @@ class Embedding1D(ParallelLayer): def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + if self.gather_output: + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel return output diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 27fd09b45..d4425497b 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -43,6 +43,15 @@ def build_policies(): from .gpt2 import GPT2LMHeadModelPolicy auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy + + from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy + from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model + t5 = { + T5ForConditionalGeneration: T5ForConditionalGenerationPolicy, + T5EncoderModel: T5EncoderModelPolicy, + T5Model: T5ModelPolicy, + } + auto_policy_dict.update(t5) return auto_policy_dict diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index d55df59fd..ba3a97f1b 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -80,6 +80,18 @@ class Dropout_Layer(Layer): p: str = None +@dataclass +class Embedding_Layer(Layer): + r""" + Class for col shard layer in tensor parrallel + + Args: + weight (str): The weight suffix of the layer + """ + weight: str = None + gather_output: bool = True + + class Policy(): r""" The base class for all the policies diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py new file mode 100644 index 000000000..7b013a378 --- /dev/null +++ b/colossalai/shardformer/policies/t5.py @@ -0,0 +1,159 @@ +from typing import Dict + +import torch.nn as nn +from torch.nn import Embedding +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5Block, + T5DenseActDense, + T5DenseGatedActDense, + T5LayerCrossAttention, + T5LayerFF, + T5LayerSelfAttention, + T5Model, + T5Stack, +) + +import colossalai.shardformer.layer.layers as col_nn + +from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer + + +class T5ModelPolicy(Policy): + + @staticmethod + def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: + print('config heads', config.num_heads) + return { + T5Stack: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]), + T5Block: + Argument(attr_dict={}, param_funcs=[]), + T5LayerSelfAttention: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + T5LayerCrossAttention: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + T5Attention: + Argument(attr_dict={ + "d_model": config.d_model // world_size, + "n_heads": config.num_heads // world_size, + "inner_dim": config.num_heads * config.d_kv // world_size, + }, + param_funcs=[T5ModelPolicy.attn_layer]), + T5LayerFF: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + T5DenseGatedActDense: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]), + T5DenseActDense: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]), + } + + @staticmethod + def dense_gated_layer(): + return [ + Col_Layer( + suffix="wi_0", + weight="weight", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="wi_1", + weight="weight", + replace_layer=col_nn.Linear1D_Row, + ), + Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True) + ] + + @staticmethod + def dense_act_layer(): + return [ + Col_Layer( + suffix="wi", + weight="weight", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="wo", + weight="weight", + replace_layer=col_nn.Linear1D_Row, + ) + ] + + @staticmethod + def attn_layer(): + return [ + Col_Layer( + suffix="q", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="k", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="v", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="o", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Row, + ), + ] + + @staticmethod + def dropout(): + return [Dropout_Layer( + suffix="dropout", + p="p", + replace_layer=col_nn.Dropout1D, + )] + + @staticmethod + def embedding(): + return [ + Embedding_Layer( + suffix="block[0].layer[0].SelfAttention.relative_attention_bias", + weight="weight", + replace_layer=col_nn.Embedding1D, + gather_output=False, + ) + ] + + +from transformers import T5ForConditionalGeneration + + +class T5ForConditionalGenerationPolicy(T5ModelPolicy): + + @staticmethod + def argument_policy(config, world_size): + base_argument = T5ModelPolicy.argument_policy(config, world_size) + argument = { + T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head]) + } + argument.update(base_argument) + return argument + + @staticmethod + def lm_head(): + return [Col_Layer( + suffix="lm_head", + weight="weight", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + )] + + +from transformers import T5EncoderModel + + +class T5EncoderModelPolicy(T5ModelPolicy): + pass diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 95184cfe6..8f6514cb4 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -5,7 +5,7 @@ import torch.nn as nn from transformers.pytorch_utils import Conv1D from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer from ..utils.utils import getattr_, hasattr_, setattr_ from .shard_config import ShardConfig from .slicer import Slicer @@ -155,11 +155,11 @@ class ModelSharder(object): assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}" if suffix_layer is None and ignore: continue - if isinstance(policy_layer, (Col_Layer, Row_Layer)): + if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)): weight = None bias = None weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None - bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None + bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -189,6 +189,11 @@ class ModelSharder(object): weight.shape[1], bias=False if bias is None else True, gather_output=gather_output) + elif replace_layer_cls.__name__ == "Embedding1D": + gather_output = policy_layer.gather_output + replace_layer = replace_layer_cls(weight.shape[0], + weight.shape[1], + gather_output=gather_output) elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D": replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{suffix}.padding_idx", ignore=True)) diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 0bf8f58b8..860533dca 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,9 +1,9 @@ import torch -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer from .shard_config import ShardConfig -dim_mapping = {Col_Layer: 0, Row_Layer: 1} +dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1} class Slicer(): @@ -43,6 +43,8 @@ class Slicer(): bias = self.slice_tensor(bias, 0, True, n_cast) elif policy_layer_cls == Row_Layer: weight = self.slice_tensor(weight, dim, False, n_cast) + elif policy_layer_cls == Embedding_Layer: + weight = self.slice_tensor(weight, dim, False, n_cast) else: raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") if reversed: diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index 2c02b6f69..05a6a3ae6 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -1,3 +1,22 @@ +import re + + +def get_obj_list_element(obj, a): + re_pattern = r'\[\d+\]' + prog = re.compile(re_pattern) + result = prog.search(a) + if result: + matched_brackets = result.group() + matched_index = matched_brackets.replace('[', '') + matched_index = matched_index.replace(']', '') + a_ = a.replace(matched_brackets, '') + container_obj = getattr(obj, a_) + obj = container_obj[int(matched_index)] + else: + obj = getattr(obj, a) + return obj + + def hasattr_(obj, attr: str): r""" Check whether the object has the multi sublevel attr @@ -9,7 +28,7 @@ def hasattr_(obj, attr: str): attrs = attr.split('.') for a in attrs: try: - obj = getattr(obj, a) + obj = get_obj_list_element(obj, a) except AttributeError: return False return True @@ -29,7 +48,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): attrs = attr.split('.') for a in attrs[:-1]: try: - obj = getattr(obj, a) + obj = get_obj_list_element(obj, a) except AttributeError: if ignore: return @@ -50,7 +69,7 @@ def getattr_(obj, attr: str, ignore: bool = False): attrs = attr.split('.') for a in attrs: try: - obj = getattr(obj, a) + obj = get_obj_list_element(obj, a) except AttributeError: if ignore: return None diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 6895113bc..50121a928 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -15,3 +15,4 @@ einops triton==2.0.0.dev20221202 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 +SentencePiece diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py new file mode 100644 index 000000000..ca44f0b00 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -0,0 +1,99 @@ +import copy +import os +import random + +import pytest +import torch +from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +tokenizer = T5Tokenizer.from_pretrained("t5-small") + + +def build_model(rank, world_size): + config = T5Config.from_pretrained("t5-small") + config.dropout_rate = 0 + org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda') + + shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, + ) + + org_model_for_shard = copy.deepcopy(org_model) + + sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + + input_ids = tokenizer("translate English to German: The house is wonderful.", + return_tensors="pt").input_ids.to('cuda') + #orgin model + org_model.eval() + org_output = org_model.generate(input_ids) + + #shard model + sharded_model.eval() + shard_output = sharded_model.generate(input_ids) + assert torch.allclose( + org_output[0], shard_output[0], + atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input_ids = tokenizer("translate English to German: The house is wonderful.", + return_tensors="pt").input_ids.to('cuda') + labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda') + + #orgin model + org_model.train() + org_loss = org_model(input_ids=input_ids, labels=labels).loss + org_loss.backward() + org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad + + #shard model + sharded_model.train() + shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss + shard_loss.backward() + shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_t5(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + org_model, sharded_model = build_model(rank, world_size) + check_forward(org_model, sharded_model) + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_t5(): + spawn(check_t5, 2) + + +if __name__ == "__main__": + test_t5()