[shardformer] shardformer support t5 model (#3994)

test t5
pull/4157/head
wukong1992 2023-06-15 16:50:08 +08:00 committed by Frank Lee
parent 6b30dfb7ce
commit c1c672d0f0
10 changed files with 320 additions and 10 deletions

Binary file not shown.


@@ -770,6 +770,7 @@ class Embedding1D(ParallelLayer):
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
@@ -782,6 +783,7 @@ class Embedding1D(ParallelLayer):
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.gather_output = gather_output
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
@@ -832,8 +834,10 @@ class Embedding1D(ParallelLayer):
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
return output
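
A minimal sketch of what the new gather_output flag controls, using plain PyTorch and a simulated 2-way split in place of the real gather_forward_split_backward collective: with the table split along the embedding dimension, gathering concatenates the per-rank partial outputs back to the full hidden size, while gather_output=False keeps only the local shard (useful when the next layer is itself column-parallel).

import torch
import torch.nn.functional as F

torch.manual_seed(0)
full_weight = torch.randn(10, 8)                   # (vocab_size, embedding_dim)
world_size = 2
shards = full_weight.chunk(world_size, dim=-1)     # each simulated rank holds (10, 4)
input_ids = torch.tensor([[1, 3, 5]])

# gather_output=False: every rank returns only its partition of the embedding output
local_outputs = [F.embedding(input_ids, w) for w in shards]
assert local_outputs[0].shape == (1, 3, 4)

# gather_output=True: concatenating along the last dim recovers the unsharded result
gathered = torch.cat(local_outputs, dim=-1)
assert torch.equal(gathered, F.embedding(input_ids, full_weight))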


@@ -43,6 +43,15 @@ def build_policies():
from .gpt2 import GPT2LMHeadModelPolicy
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
t5 = {
T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
T5EncoderModel: T5EncoderModelPolicy,
T5Model: T5ModelPolicy,
}
auto_policy_dict.update(t5)
return auto_policy_dict
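
A rough sketch of how a class-keyed registry like auto_policy_dict resolves a policy from a model instance; the classes below are stand-ins rather than the real transformers/policy types, and get_policy is a hypothetical helper, not the actual get_autopolicy.

# hypothetical stand-ins for the real model and policy classes
class T5Model: ...
class T5EncoderModel: ...
class T5ModelPolicy: ...
class T5EncoderModelPolicy: ...

auto_policy_dict = {
    T5Model: T5ModelPolicy,
    T5EncoderModel: T5EncoderModelPolicy,
}

def get_policy(model):
    # look the policy up by the concrete model class, as build_policies() now enables for T5
    policy_cls = auto_policy_dict.get(type(model))
    if policy_cls is None:
        raise NotImplementedError(f"No shard policy registered for {type(model).__name__}")
    return policy_cls()

print(type(get_policy(T5EncoderModel())).__name__)    # T5EncoderModelPolicy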


@@ -80,6 +80,18 @@ class Dropout_Layer(Layer):
p: str = None
@dataclass
class Embedding_Layer(Layer):
r"""
Class for the sharded embedding layer in tensor parallel
Args:
weight (str): The weight suffix of the layer
gather_output (bool): Whether to gather the full embedding output across the tensor parallel group
"""
weight: str = None
gather_output: bool = True
class Policy():
r"""
The base class for all the policies


@@ -0,0 +1,159 @@
from typing import Dict
import torch.nn as nn
from torch.nn import Embedding
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5Block,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Model,
T5Stack,
)
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer
class T5ModelPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
return {
T5Stack:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
T5Block:
Argument(attr_dict={}, param_funcs=[]),
T5LayerSelfAttention:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
T5LayerCrossAttention:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
T5Attention:
Argument(attr_dict={
"d_model": config.d_model // world_size,
"n_heads": config.num_heads // world_size,
"inner_dim": config.num_heads * config.d_kv // world_size,
},
param_funcs=[T5ModelPolicy.attn_layer]),
T5LayerFF:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
T5DenseGatedActDense:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
T5DenseActDense:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
}
@staticmethod
def dense_gated_layer():
return [
Col_Layer(
suffix="wi_0",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="wi_1",
weight="weight",
replace_layer=col_nn.Linear1D_Row,
),
Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
]
@staticmethod
def dense_act_layer():
return [
Col_Layer(
suffix="wi",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="wo",
weight="weight",
replace_layer=col_nn.Linear1D_Row,
)
]
@staticmethod
def attn_layer():
return [
Col_Layer(
suffix="q",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="k",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="v",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="o",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
),
]
@staticmethod
def dropout():
return [Dropout_Layer(
suffix="dropout",
p="p",
replace_layer=col_nn.Dropout1D,
)]
@staticmethod
def embedding():
return [
Embedding_Layer(
suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
weight="weight",
replace_layer=col_nn.Embedding1D,
gather_output=False,
)
]
from transformers import T5ForConditionalGeneration
class T5ForConditionalGenerationPolicy(T5ModelPolicy):
@staticmethod
def argument_policy(config, world_size):
base_argument = T5ModelPolicy.argument_policy(config, world_size)
argument = {
T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
}
argument.update(base_argument)
return argument
@staticmethod
def lm_head():
return [Col_Layer(
suffix="lm_head",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)]
from transformers import T5EncoderModel
class T5EncoderModelPolicy(T5ModelPolicy):
pass
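
As a quick sanity check on the attr_dict arithmetic in T5ModelPolicy.argument_policy, the numbers below use the standard t5-small hyperparameters (d_model=512, num_heads=8, d_kv=64, taken from the usual Hugging Face config rather than from this diff) with a tensor parallel world size of 2.

# t5-small hyperparameters (assumed from the standard config) and a 2-way tensor parallel group
d_model, num_heads, d_kv = 512, 8, 64
world_size = 2

# per-rank attributes that T5ModelPolicy writes onto each T5Attention module
sharded_attrs = {
    "d_model": d_model // world_size,              # 256
    "n_heads": num_heads // world_size,            # 4 attention heads per rank
    "inner_dim": num_heads * d_kv // world_size,   # 256, the local width of the q/k/v projections
}

# the local inner_dim stays consistent with the local head count
assert sharded_attrs["inner_dim"] == sharded_attrs["n_heads"] * d_kv
print(sharded_attrs)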


@@ -5,7 +5,7 @@ import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
from ..utils.utils import getattr_, hasattr_, setattr_
from .shard_config import ShardConfig
from .slicer import Slicer
@@ -155,11 +155,11 @@ class ModelSharder(object):
assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
if suffix_layer is None and ignore:
continue
if isinstance(policy_layer, (Col_Layer, Row_Layer)):
if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
weight = None
bias = None
weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None
bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
@@ -189,6 +189,11 @@ class ModelSharder(object):
weight.shape[1],
bias=False if bias is None else True,
gather_output=gather_output)
elif replace_layer_cls.__name__ == "Embedding1D":
gather_output = policy_layer.gather_output
replace_layer = replace_layer_cls(weight.shape[0],
weight.shape[1],
gather_output=gather_output)
elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))


@@ -1,9 +1,9 @@
import torch
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
from .shard_config import ShardConfig
dim_mapping = {Col_Layer: 0, Row_Layer: 1}
dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}
class Slicer():
@@ -43,6 +43,8 @@ class Slicer():
bias = self.slice_tensor(bias, 0, True, n_cast)
elif policy_layer_cls == Row_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
elif policy_layer_cls == Embedding_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
else:
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
if reversed:
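
A small illustration of the slicing direction that the new dim_mapping entry implies: Embedding_Layer maps to dim 1, so an embedding table of shape (vocab_size, hidden_size) is split along the hidden dimension, one chunk per rank (plain torch here, with made-up sizes and no ShardConfig).

import torch

vocab_size, hidden_size, world_size = 32128, 512, 2   # made-up, t5-small-like sizes
weight = torch.randn(vocab_size, hidden_size)

# Embedding_Layer -> dim 1 in dim_mapping: split the hidden dimension across ranks
embedding_shards = weight.chunk(world_size, dim=1)
assert all(s.shape == (vocab_size, hidden_size // world_size) for s in embedding_shards)

# by contrast, Col_Layer -> dim 0 would split a linear weight along its output rows
linear_weight = torch.randn(2048, hidden_size)
col_shards = linear_weight.chunk(world_size, dim=0)
assert col_shards[0].shape == (1024, hidden_size)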


@@ -1,3 +1,22 @@
import re
def get_obj_list_element(obj, a):
re_pattern = r'\[\d+\]'
prog = re.compile(re_pattern)
result = prog.search(a)
if result:
matched_brackets = result.group()
matched_index = matched_brackets.replace('[', '')
matched_index = matched_index.replace(']', '')
a_ = a.replace(matched_brackets, '')
container_obj = getattr(obj, a_)
obj = container_obj[int(matched_index)]
else:
obj = getattr(obj, a)
return obj
def hasattr_(obj, attr: str):
r"""
Check whether the object has the multi sublevel attr
@@ -9,7 +28,7 @@ def hasattr_(obj, attr: str):
attrs = attr.split('.')
for a in attrs:
try:
obj = getattr(obj, a)
obj = get_obj_list_element(obj, a)
except AttributeError:
return False
return True
@@ -29,7 +48,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
attrs = attr.split('.')
for a in attrs[:-1]:
try:
obj = getattr(obj, a)
obj = get_obj_list_element(obj, a)
except AttributeError:
if ignore:
return
@@ -50,7 +69,7 @@ def getattr_(obj, attr: str, ignore: bool = False):
attrs = attr.split('.')
for a in attrs:
try:
obj = getattr(obj, a)
obj = get_obj_list_element(obj, a)
except AttributeError:
if ignore:
return None
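
A toy, self-contained illustration of the dotted-path access that get_obj_list_element adds: a segment such as 'block[0]' is split into the attribute name and the list index, so policy suffixes like 'block[0].layer[0].SelfAttention.relative_attention_bias' can walk through ModuleList-style containers. The resolve helper and container classes below are made up for the demo; they only mirror the new logic.

import re

def resolve(obj, attr_path):
    # mirrors the get_obj_list_element logic: handle "name[index]" segments
    for seg in attr_path.split('.'):
        match = re.search(r'\[\d+\]', seg)
        if match:
            index = int(match.group().strip('[]'))
            obj = getattr(obj, seg.replace(match.group(), ''))[index]
        else:
            obj = getattr(obj, seg)
    return obj

# made-up nested objects standing in for a T5 module tree
class Layer:
    def __init__(self, name):
        self.name = name

class Block:
    def __init__(self):
        self.layer = [Layer("self_attention"), Layer("cross_attention")]

class Encoder:
    def __init__(self):
        self.block = [Block()]

print(resolve(Encoder(), "block[0].layer[1].name"))    # cross_attention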


@@ -15,3 +15,4 @@ einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece


@@ -0,0 +1,99 @@
import copy
import os
import random
import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = T5Tokenizer.from_pretrained("t5-small")
def build_model(rank, world_size):
config = T5Config.from_pretrained("t5-small")
config.dropout_rate = 0
org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda')
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
org_model_for_shard = copy.deepcopy(org_model)
sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda')
return org_model, sharded_model
def check_forward(org_model, sharded_model):
input_ids = tokenizer("translate English to German: The house is wonderful.",
return_tensors="pt").input_ids.to('cuda')
# original model
org_model.eval()
org_output = org_model.generate(input_ids)
# sharded model
sharded_model.eval()
shard_output = sharded_model.generate(input_ids)
assert torch.allclose(
org_output[0], shard_output[0],
atol=1e-5), f"shard model output is not equal to original model output\n{org_output[0]}\n{shard_output[0]}"
def check_backward(org_model, sharded_model):
# prepare input
input_ids = tokenizer("translate English to German: The house is wonderful.",
return_tensors="pt").input_ids.to('cuda')
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')
# original model
org_model.train()
org_loss = org_model(input_ids=input_ids, labels=labels).loss
org_loss.backward()
org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
# sharded model
sharded_model.train()
shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss
shard_loss.backward()
shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
# all_gather returns None in the synchronous case, so collect the shards in-place
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
org_model, sharded_model = build_model(rank, world_size)
check_forward(org_model, sharded_model)
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_t5():
spawn(check_t5, 2)
if __name__ == "__main__":
test_t5()