mirror of https://github.com/hpcaitech/ColossalAI
parent 6b30dfb7ce
commit c1c672d0f0

Binary file not shown.
@@ -770,6 +770,7 @@ class Embedding1D(ParallelLayer):
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: torch.dtype = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
@@ -782,6 +783,7 @@ class Embedding1D(ParallelLayer):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.gather_output = gather_output
 
         self.weight = Parameter(
             torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
@@ -832,8 +834,10 @@ class Embedding1D(ParallelLayer):
     def forward(self, input_: Tensor) -> Tensor:
 
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
-        output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        else:
+            output = output_parallel
 
         return output

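A minimal usage sketch of the new gather_output flag (not part of this commit; it assumes colossalai has already been launched with a 1D tensor-parallel group, and that this Embedding1D is the layer exposed as col_nn.Embedding1D in the shardformer policies below):

    import torch
    import colossalai.shardformer.layer.layers as col_nn
    from colossalai.utils import get_current_device

    # Each rank holds a (num_embeddings, embedding_dim // world_size) shard of the table.
    embed = col_nn.Embedding1D(num_embeddings=32128, embedding_dim=512, gather_output=False)
    tokens = torch.randint(0, 32128, (4, 16), device=get_current_device())
    local_out = embed(tokens)    # last dim stays sharded: embedding_dim // world_size per rank
    # With gather_output=True (the default) the shards are concatenated along dim=-1,
    # so every rank sees the full embedding_dim.
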
@@ -43,6 +43,15 @@ def build_policies():
 
     from .gpt2 import GPT2LMHeadModelPolicy
     auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
 
+    from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
+    from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
+    t5 = {
+        T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
+        T5EncoderModel: T5EncoderModelPolicy,
+        T5Model: T5ModelPolicy,
+    }
+    auto_policy_dict.update(t5)
+
     return auto_policy_dict

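A brief sketch of how this registry is consumed (illustrative only; get_autopolicy, imported by the sharder further down, is assumed to do a lookup of this kind):

    from transformers import T5Model

    auto_policy_dict = build_policies()           # the dict extended above
    policy_cls = auto_policy_dict.get(T5Model)    # T5ModelPolicy for a bare T5Model
    if policy_cls is None:
        raise NotImplementedError("no shard policy registered for this model class")
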
@@ -80,6 +80,18 @@ class Dropout_Layer(Layer):
     p: str = None
 
 
+@dataclass
+class Embedding_Layer(Layer):
+    r"""
+    Class for the embedding shard layer in tensor parallel
+
+    Args:
+        weight (str): The weight suffix of the layer
+    """
+    weight: str = None
+    gather_output: bool = True
+
+
 class Policy():
     r"""
     The base class for all the policies
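An illustrative descriptor built from the new dataclass (the suffix here is hypothetical; the field names follow Embedding_Layer above and the col_nn.Embedding1D replacement used by the T5 policy below):

    import colossalai.shardformer.layer.layers as col_nn

    # Replace a sub-module's embedding table with its 1D-parallel version and keep
    # the output sharded on each rank instead of gathering it.
    bias_table = Embedding_Layer(
        suffix="relative_attention_bias",    # hypothetical sub-module path
        weight="weight",
        replace_layer=col_nn.Embedding1D,
        gather_output=False,
    )
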
@@ -0,0 +1,159 @@
from typing import Dict

import torch.nn as nn
from torch.nn import Embedding
from transformers.models.t5.modeling_t5 import (
    T5Attention,
    T5Block,
    T5DenseActDense,
    T5DenseGatedActDense,
    T5LayerCrossAttention,
    T5LayerFF,
    T5LayerSelfAttention,
    T5Model,
    T5Stack,
)

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer


class T5ModelPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        print('config heads', config.num_heads)
        return {
            T5Stack:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
            T5Block:
                Argument(attr_dict={}, param_funcs=[]),
            T5LayerSelfAttention:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
            T5LayerCrossAttention:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
            T5Attention:
                Argument(attr_dict={
                    "d_model": config.d_model // world_size,
                    "n_heads": config.num_heads // world_size,
                    "inner_dim": config.num_heads * config.d_kv // world_size,
                },
                         param_funcs=[T5ModelPolicy.attn_layer]),
            T5LayerFF:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
            T5DenseGatedActDense:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
            T5DenseActDense:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
        }

    @staticmethod
    def dense_gated_layer():
        return [
            Col_Layer(
                suffix="wi_0",
                weight="weight",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Row_Layer(
                suffix="wi_1",
                weight="weight",
                replace_layer=col_nn.Linear1D_Row,
            ),
            Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
        ]

    @staticmethod
    def dense_act_layer():
        return [
            Col_Layer(
                suffix="wi",
                weight="weight",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Row_Layer(
                suffix="wo",
                weight="weight",
                replace_layer=col_nn.Linear1D_Row,
            )
        ]

    @staticmethod
    def attn_layer():
        return [
            Col_Layer(
                suffix="q",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="k",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="v",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Row_Layer(
                suffix="o",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Row,
            ),
        ]

    @staticmethod
    def dropout():
        return [Dropout_Layer(
            suffix="dropout",
            p="p",
            replace_layer=col_nn.Dropout1D,
        )]

    @staticmethod
    def embedding():
        return [
            Embedding_Layer(
                suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
                weight="weight",
                replace_layer=col_nn.Embedding1D,
                gather_output=False,
            )
        ]


from transformers import T5ForConditionalGeneration


class T5ForConditionalGenerationPolicy(T5ModelPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = T5ModelPolicy.argument_policy(config, world_size)
        argument = {
            T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def lm_head():
        return [Col_Layer(
            suffix="lm_head",
            weight="weight",
            replace_layer=col_nn.Linear1D_Col,
            gather_output=True,
        )]


from transformers import T5EncoderModel


class T5EncoderModelPolicy(T5ModelPolicy):
    pass

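A quick single-process sketch for inspecting the policy (illustrative; world_size=2 is just an example and loading t5-small requires network access):

    from transformers import T5Config

    config = T5Config.from_pretrained("t5-small")
    mapping = T5ModelPolicy.argument_policy(config, world_size=2)
    for module_cls, arg in mapping.items():
        # e.g. T5Attention gets d_model / n_heads / inner_dim divided by world_size,
        # plus the attn_layer param_funcs that swap q/k/v/o for 1D-parallel linears.
        print(module_cls.__name__, arg.attr_dict, [f.__name__ for f in arg.param_funcs])
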
@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers.pytorch_utils import Conv1D
 
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
+from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
 from ..utils.utils import getattr_, hasattr_, setattr_
 from .shard_config import ShardConfig
 from .slicer import Slicer
@@ -155,11 +155,11 @@ class ModelSharder(object):
             assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
             if suffix_layer is None and ignore:
                 continue
-            if isinstance(policy_layer, (Col_Layer, Row_Layer)):
+            if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
                 weight = None
                 bias = None
                 weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
-                bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None
+                bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None
 
                 if weight_attr is not None:
                     if hasattr_(org_layer, weight_attr):
@@ -189,6 +189,11 @@ class ModelSharder(object):
                                                   weight.shape[1],
                                                   bias=False if bias is None else True,
                                                   gather_output=gather_output)
+            elif replace_layer_cls.__name__ == "Embedding1D":
+                gather_output = policy_layer.gather_output
+                replace_layer = replace_layer_cls(weight.shape[0],
+                                                  weight.shape[1],
+                                                  gather_output=gather_output)
             elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
                 replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
                                                   getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))

@@ -1,9 +1,9 @@
 import torch
 
-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer
+from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
 from .shard_config import ShardConfig
 
-dim_mapping = {Col_Layer: 0, Row_Layer: 1}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}
 
 
 class Slicer():
@@ -43,6 +43,8 @@ class Slicer():
             bias = self.slice_tensor(bias, 0, True, n_cast)
         elif policy_layer_cls == Row_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
+        elif policy_layer_cls == Embedding_Layer:
+            weight = self.slice_tensor(weight, dim, False, n_cast)
         else:
             raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
         if reversed:

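For intuition, a standalone sketch of what the dim=1 entry for Embedding_Layer means for an embedding weight (plain torch chunking, not the Slicer API; shapes are illustrative):

    import torch

    world_size, rank = 2, 0
    weight = torch.randn(32128, 512)                 # (num_embeddings, embedding_dim)
    shards = torch.chunk(weight, world_size, dim=1)  # dim_mapping[Embedding_Layer] == 1
    local_weight = shards[rank]                      # (32128, 256) kept by this rank
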
@@ -1,3 +1,22 @@
+import re
+
+
+def get_obj_list_element(obj, a):
+    re_pattern = r'\[\d+\]'
+    prog = re.compile(re_pattern)
+    result = prog.search(a)
+    if result:
+        matched_brackets = result.group()
+        matched_index = matched_brackets.replace('[', '')
+        matched_index = matched_index.replace(']', '')
+        a_ = a.replace(matched_brackets, '')
+        container_obj = getattr(obj, a_)
+        obj = container_obj[int(matched_index)]
+    else:
+        obj = getattr(obj, a)
+    return obj
+
+
 def hasattr_(obj, attr: str):
     r"""
     Check whether the object has the multi sublevel attr
@@ -9,7 +28,7 @@ def hasattr_(obj, attr: str):
     attrs = attr.split('.')
     for a in attrs:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             return False
     return True
@@ -29,7 +48,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
     attrs = attr.split('.')
     for a in attrs[:-1]:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             if ignore:
                 return
@@ -50,7 +69,7 @@ def getattr_(obj, attr: str, ignore: bool = False):
     attrs = attr.split('.')
     for a in attrs:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             if ignore:
                 return None

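A minimal sketch of the new indexed-attribute handling (hasattr_ and getattr_ are the helpers above, assumed importable from colossalai.shardformer.utils.utils; the toy object stands in for a real T5 module tree):

    from types import SimpleNamespace

    from colossalai.shardformer.utils.utils import getattr_, hasattr_

    # obj.block is a list and obj.block[0].layer is a list, mirroring the policy suffix
    # "block[0].layer[0].SelfAttention.relative_attention_bias" used above.
    leaf = SimpleNamespace(relative_attention_bias="embedding goes here")
    obj = SimpleNamespace(block=[SimpleNamespace(layer=[SimpleNamespace(SelfAttention=leaf)])])

    assert hasattr_(obj, "block[0].layer[0].SelfAttention.relative_attention_bias")
    assert getattr_(obj, "block[0].layer[0].SelfAttention.relative_attention_bias") == "embedding goes here"
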
@@ -15,3 +15,4 @@ einops
 triton==2.0.0.dev20221202
 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
+SentencePiece

@@ -0,0 +1,99 @@
import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = T5Tokenizer.from_pretrained("t5-small")


def build_model(rank, world_size):
    config = T5Config.from_pretrained("t5-small")
    config.dropout_rate = 0
    org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda')

    shardconfig = ShardConfig(
        rank=rank,
        world_size=world_size,
        gather_output=True,
    )

    org_model_for_shard = copy.deepcopy(org_model)

    sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda')

    return org_model, sharded_model


def check_forward(org_model, sharded_model):

    input_ids = tokenizer("translate English to German: The house is wonderful.",
                          return_tensors="pt").input_ids.to('cuda')
    # original model
    org_model.eval()
    org_output = org_model.generate(input_ids)

    # sharded model
    sharded_model.eval()
    shard_output = sharded_model.generate(input_ids)
    assert torch.allclose(
        org_output[0], shard_output[0],
        atol=1e-5), f"shard model output is not equal to origin model output\n{org_output[0]}\n{shard_output[0]}"


def check_backward(org_model, sharded_model):
    # prepare input
    input_ids = tokenizer("translate English to German: The house is wonderful.",
                          return_tensors="pt").input_ids.to('cuda')
    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')

    # original model
    org_model.train()
    org_loss = org_model(input_ids=input_ids, labels=labels).loss
    org_loss.backward()
    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad

    # sharded model
    sharded_model.train()
    shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss
    shard_loss.backward()
    shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad

    # all_gather fills shard_grad_list in place (it returns None), then the shards are concatenated
    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
    torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


def check_t5(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    org_model, sharded_model = build_model(rank, world_size)
    check_forward(org_model, sharded_model)
    check_backward(org_model, sharded_model)

    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_t5():
    spawn(check_t5, 2)


if __name__ == "__main__":
    test_t5()