[shardformer] supported T5 and its variants (#4045)

pull/4157/head
Frank Lee 2023-06-19 17:57:37 +08:00
parent c1d5453e9f
commit d857f3dbba
10 changed files with 316 additions and 221 deletions

View File

@@ -81,8 +81,8 @@ We will follow this roadmap to develop Shardformer:
 - [ ] Hugging Face
 - [ ] NLP
 - [x] BERT
-- [ ] T5
-- [ ] LlaMa
+- [x] T5
+- [x] LlaMa
 - [ ] GPT2
 - [ ] BLOOM
 - [ ] RoBERTa
@@ -90,7 +90,6 @@ We will follow this roadmap to develop Shardformer:
 - [ ] ERNIE
 - [ ] GPT Neo
 - [ ] GPT-J
-- [ ] CV
 - [ ] CV
 - [ ] ViT
 - [ ] BEiT

View File

@@ -469,13 +469,14 @@ class Embedding1D(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(process_group)
         self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@@ -499,7 +500,9 @@ class Embedding1D(ParallelModule):
     @staticmethod
     def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
@@ -527,7 +530,9 @@ class Embedding1D(ParallelModule):
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
-                                sparse=sparse)
+                                sparse=sparse,
+                                *args,
+                                **kwargs)

         # copy the weight
         with torch.no_grad():
@@ -537,7 +542,7 @@ class Embedding1D(ParallelModule):
         return embedding

     def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embed_dim
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()
@@ -548,9 +553,12 @@ class Embedding1D(ParallelModule):
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
-        return output
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel


 class VocabParallelEmbedding1D(ParallelLayer):
@@ -595,7 +603,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
@@ -610,7 +618,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+            torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype))

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -662,7 +670,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
     def reset_parameters(self, weight_initializer) -> None:
         with seed(ParallelMode.TENSOR):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
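The new `gather_output` switch on `Embedding1D` mirrors the one on `Linear1D_Col`: when it is off, the embedding output stays split along the embedding dimension instead of being all-gathered. A minimal usage sketch, assuming an already-initialized tensor-parallel process group (the helper name and `tp_group` argument are hypothetical):

import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.shardformer.layer.layers import Embedding1D


def shard_embedding(native: nn.Embedding, tp_group: ProcessGroup) -> Embedding1D:
    # Keep the output sharded along the embedding dimension so a downstream
    # tensor-parallel consumer (e.g. T5's relative attention bias path) can
    # use its local slice directly.
    return Embedding1D.from_native_module(native, process_group=tp_group, gather_output=False)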

View File

@@ -48,6 +48,12 @@ _POLICY_LIST = {
         PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),

     # T5
+    "transformers.models.t5.modeling_t5.T5Model":
+        PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
+    "transformers.models.t5.modeling_t5.T5ForConditionalGeneration":
+        PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"),
+    "transformers.models.t5.modeling_t5.T5EncoderModel":
+        PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),

     # GPT2
 }
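The registry keys are fully qualified class names, so the auto-policy lookup only needs to stringify a model's class and import the policy the entry points to. A hedged sketch of that resolution step; `resolve_policy` is a hypothetical helper and the `colossalai.shardformer.policies` package path is an assumption:

import importlib


def resolve_policy(model, policy_list):
    # Build the fully qualified class name, e.g.
    # "transformers.models.t5.modeling_t5.T5ForConditionalGeneration".
    qualname = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = policy_list.get(qualname)
    if location is None:
        raise NotImplementedError(f"no shardformer policy registered for {qualname}")
    # PolicyLocation stores the policy file (module) name and the class name.
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()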

View File

@@ -27,6 +27,7 @@ class SubModuleReplacementDescription:
     suffix: str
     target_module: ParallelModule
     kwargs: Dict[str, Any] = None
+    ignore_if_not_exist: bool = False


 @dataclass
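The new `ignore_if_not_exist` flag is for submodules that only some instances of a layer own. In T5, only the first block carries `relative_attention_bias`, so the T5 policy below marks that description as ignorable; a sketch of such a description (import paths are assumed):

from colossalai.shardformer.layer.layers import Embedding1D
from colossalai.shardformer.policies.basepolicy import SubModuleReplacementDescription

# Replace relative_attention_bias with a column-sharded Embedding1D where it
# exists, and silently skip attention layers that do not have one.
relative_bias_desc = SubModuleReplacementDescription(suffix="relative_attention_bias",
                                                     target_module=Embedding1D,
                                                     kwargs=dict(gather_output=False),
                                                     ignore_if_not_exist=True)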

View File

@@ -1,159 +1,173 @@
-from typing import Dict
+import torch
 import torch.nn as nn
-from torch.nn import Embedding
+from transformers import T5ForConditionalGeneration
 from transformers.models.t5.modeling_t5 import (
     T5Attention,
-    T5Block,
     T5DenseActDense,
     T5DenseGatedActDense,
     T5LayerCrossAttention,
     T5LayerFF,
     T5LayerSelfAttention,
-    T5Model,
     T5Stack,
 )

-import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.dropout import Dropout1D
+from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row

-from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]


 class T5ModelPolicy(Policy):

-    @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
-        print('config heads', config.num_heads)
+    def preprocess(self):
+        # reshape the embedding layer
+        r"""
+        Reshape the Embedding layer to make the embedding dimension divisible by world_size
+        """
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        return self.model
+
+    def module_policy(self):
         return {
             T5Stack:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
-            T5Block:
-                Argument(attr_dict={}, param_funcs=[]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5LayerSelfAttention:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            ),
+                                        ]),
             T5LayerCrossAttention:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5Attention:
-                Argument(attr_dict={
-                    "d_model": config.d_model // world_size,
-                    "n_heads": config.num_heads // world_size,
-                    "inner_dim": config.num_heads * config.d_kv // world_size,
+                ModulePolicyDescription(attribute_replacement={
+                    "d_model":
+                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
+                    "n_heads":
+                        self.model.config.num_heads // self.shard_config.tensor_parallel_size,
+                    "inner_dim":
+                        self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
                 },
-                         param_funcs=[T5ModelPolicy.attn_layer]),
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="q",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="k",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="v",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="o",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(suffix="relative_attention_bias",
+                                                                            target_module=Embedding1D,
+                                                                            kwargs=dict(gather_output=False),
+                                                                            ignore_if_not_exist=True)
+                                        ]),
             T5LayerFF:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            ),
+                                        ]),
             T5DenseGatedActDense:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="wi_0",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="wi_1",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(suffix="wo",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True)),
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5DenseActDense:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="wi",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="wo",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ])
         }

-    @staticmethod
-    def dense_gated_layer():
-        return [
-            Col_Layer(
-                suffix="wi_0",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="wi_1",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-            Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
-        ]
-
-    @staticmethod
-    def dense_act_layer():
-        return [
-            Col_Layer(
-                suffix="wi",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="wo",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Row,
-            )
-        ]
-
-    @staticmethod
-    def attn_layer():
-        return [
-            Col_Layer(
-                suffix="q",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="k",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="v",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="o",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-        ]
-
-    @staticmethod
-    def dropout():
-        return [Dropout_Layer(
-            suffix="dropout",
-            p="p",
-            replace_layer=col_nn.Dropout1D,
-        )]
-
-    @staticmethod
-    def embedding():
-        return [
-            Embedding_Layer(
-                suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
-                weight="weight",
-                replace_layer=col_nn.Embedding1D,
-                gather_output=False,
-            )
-        ]
-
-
-from transformers import T5ForConditionalGeneration
+    def new_model_class(self):
+        return None
+
+    def postprocess(self):
+        return self.model


 class T5ForConditionalGenerationPolicy(T5ModelPolicy):

-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = T5ModelPolicy.argument_policy(config, world_size)
-        argument = {
-            T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
+    def module_policy(self):
+        policy = super().module_policy()
+
+        new_item = {
+            T5ForConditionalGeneration:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True))
+                                        ])
         }
-        argument.update(base_argument)
-        return argument
-
-    @staticmethod
-    def lm_head():
-        return [Col_Layer(
-            suffix="lm_head",
-            weight="weight",
-            replace_layer=col_nn.Linear1D_Col,
-            gather_output=True,
-        )]
-
-
-from transformers import T5EncoderModel
-
-
-class T5EncoderModelPolicy(T5ModelPolicy):
+
+        policy.update(new_item)
+        return policy
+
+
+class T5EncoderPolicy(T5ModelPolicy):
     pass
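With the three policies registered, sharding a T5 variant goes through the same ShardFormer entry point that the tests below use. A condensed sketch of that flow, assuming `colossalai.launch` has already set up the distributed environment:

import copy

from transformers import T5Config, T5ForConditionalGeneration

from colossalai.shardformer.shard import ShardConfig, ShardFormer


def build_sharded_t5(world_size: int):
    # Small config so the example stays cheap; dropout disabled to keep the
    # original and sharded outputs comparable.
    config = T5Config(decoder_start_token_id=0)
    config.dropout_rate = 0
    org_model = T5ForConditionalGeneration(config).cuda()

    shard_config = ShardConfig(tensor_parallel_size=world_size)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(copy.deepcopy(org_model))
    return org_model, sharded_model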

View File

@@ -175,7 +175,16 @@ class ModelSharder(object):
             assert target_module is not None, 'target_module should not be None'

             # TODO: support different parallel mode
-            native_sub_module = getattr_(org_layer, suffix)
+            native_sub_module = getattr_(org_layer, suffix, ignore=True)
+
+            assert not isinstance(native_sub_module, target_module), \
+                f"The module with suffix {suffix} has been replaced, please check the policy"
+
+            # if it is None and we are allowed to ignore this module
+            # just skip
+            if description.ignore_if_not_exist and native_sub_module is None:
+                continue
+
             replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
                                                              **kwargs)
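`getattr_(org_layer, suffix, ignore=True)` resolves a dotted suffix and, instead of raising, returns `None` when the attribute chain is missing, which is what makes the `ignore_if_not_exist` skip above possible. A hedged sketch of that behavior; the real helper lives in the shardformer utilities and may differ in detail:

from functools import reduce


def getattr_sketch(obj, suffix: str, ignore: bool = False):
    # Walk a dotted attribute path such as "SelfAttention.relative_attention_bias".
    try:
        return reduce(getattr, suffix.split("."), obj)
    except AttributeError:
        if ignore:
            return None
        raise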

View File

@@ -3,6 +3,7 @@ from .comparison import (
     assert_close_loose,
     assert_equal,
     assert_equal_in_group,
+    assert_hf_output_close,
     assert_not_equal,
     check_state_dict_equal,
 )
@@ -20,5 +21,5 @@ from .utils import (
 __all__ = [
     'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
     'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
-    'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal'
+    'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close'
 ]

View File

@@ -1,4 +1,4 @@
-from typing import OrderedDict
+from typing import Any, List, OrderedDict

 import torch
 import torch.distributed as dist
@@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
             assert torch.equal(v, d2[k])
         else:
             assert v == d2[k]
+
+
+def assert_hf_output_close(out1: Any,
+                           out2: Any,
+                           ignore_keys: List[str] = None,
+                           track_name: str = "",
+                           atol=1e-5,
+                           rtol=1e-5):
+    """
+    Check if two outputs from huggingface are equal.
+
+    Args:
+        out1 (Any): the first output
+        out2 (Any): the second output
+        ignore_keys (List[str]): the keys to ignore when comparing two dicts
+        track_name (str): the name of the value compared, used to track the path
+    """
+    if isinstance(out1, dict) and isinstance(out2, dict):
+        # if two values are dict
+        # we recursively check the keys
+        assert set(out1.keys()) == set(out2.keys())
+        for k in out1.keys():
+            if ignore_keys is not None and k in ignore_keys:
+                continue
+            assert_hf_output_close(out1[k],
+                                   out2[k],
+                                   track_name=f"{track_name}.{k}",
+                                   ignore_keys=ignore_keys,
+                                   atol=atol,
+                                   rtol=rtol)
+    elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)):
+        # if two values are list
+        # we recursively check the elements
+        assert len(out1) == len(out2)
+        for i in range(len(out1)):
+            assert_hf_output_close(out1[i],
+                                   out2[i],
+                                   track_name=f"{track_name}.{i}",
+                                   ignore_keys=ignore_keys,
+                                   atol=atol,
+                                   rtol=rtol)
+    elif isinstance(out1, Tensor) and isinstance(out2, Tensor):
+        if out1.shape != out2.shape:
+            raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
+        assert torch.allclose(
+            out1, out2, atol=atol, rtol=rtol
+        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
+    else:
+        assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
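A quick usage sketch of the new helper, mirroring how the tests below call it: nested HuggingFace-style outputs are compared recursively, while sharded entries such as `past_key_values` are skipped.

import torch

from colossalai.testing import assert_hf_output_close

out_a = {"logits": torch.ones(2, 4), "past_key_values": torch.zeros(1)}
out_b = {"logits": torch.ones(2, 4), "past_key_values": torch.ones(1)}

# Passes: the differing past_key_values entry is ignored and the logits match.
assert_hf_output_close(out_a, out_b, ignore_keys=["past_key_values"], atol=1e-5, rtol=1e-5)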

View File

@@ -9,7 +9,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassifi
 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
@@ -17,7 +17,11 @@ tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokeni
 def build_model(world_size, model_fn):
     # create new model
-    config = LlamaConfig(num_hidden_layers=8)
+    config = LlamaConfig(num_hidden_layers=4,
+                         hidden_size=128,
+                         intermediate_size=256,
+                         num_attention_heads=4,
+                         max_position_embeddings=128)
     org_model = model_fn(config).cuda()

     # shard model
@@ -30,49 +34,47 @@ def build_model(world_size, model_fn):
     return org_model, sharded_model


-def check_forward(org_model, sharded_model):
-    input = 'Hello, my dog is cute'
-    inputs = tokenizer(input, return_tensors='pt').to('cuda')
-    del inputs["token_type_ids"]
-    del inputs["attention_mask"]
-
-    #orgin model
-    org_model.eval()
-    org_out = org_model(**inputs)
-
-    #shard model
-    sharded_model.eval()
-    shard_out = sharded_model(**inputs)
-
-    assert torch.allclose(
-        org_out[0], shard_out[0],
-        atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
-
-
-def check_backward(org_model, sharded_model):
+def check_forward_backward(org_model, sharded_model):
     # prepare input
     input = 'Hello, my dog is cute'
     tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
     del tokenized_input["token_type_ids"]
     del tokenized_input["attention_mask"]
-    labels = tokenized_input['input_ids'].clone()
-    labels[labels == tokenizer.pad_token_id] = -100
-    tokenized_input['labels'] = labels

-    #orgin model
+    # switch to train mode
     org_model.train()
-    org_out = org_model(**tokenized_input)
-    org_loss = org_out.loss
-    org_loss.backward()
-    org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad
-
-    torch.cuda.empty_cache()
-    #shard model
     sharded_model.train()
-    shard_out = sharded_model(**tokenized_input)
-    shard_loss = shard_out.loss
+
+    if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)):
+        org_output = org_model(**tokenized_input)
+        org_loss = org_output.last_hidden_state.mean()
+        shard_output = sharded_model(**tokenized_input)
+        shard_loss = shard_output.last_hidden_state.mean()
+    elif isinstance(org_model, LlamaForCausalLM):
+        labels = tokenized_input['input_ids'].clone()
+        labels[labels == tokenizer.pad_token_id] = -100
+        tokenized_input['labels'] = labels
+        org_output = org_model(**tokenized_input)
+        org_loss = org_output.loss
+        shard_output = sharded_model(**tokenized_input)
+        shard_loss = shard_output.loss
+
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
+
+    # run backward
+    org_loss.backward()
     shard_loss.backward()
-    shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad
+
+    # check grad
+    if isinstance(org_model, LlamaModel):
+        llama_model = org_model
+        shard_llama_model = sharded_model
+    else:
+        llama_model = org_model.model
+        shard_llama_model = sharded_model.model
+
+    org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
+    shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad

     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)
@@ -88,23 +90,23 @@ def check_llama(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     model_list = [
-        LlamaForCausalLM,
+        LlamaModel,
+        # LlamaForCausalLM,
         # TODO: do not work yet
-        # LlamaModel,
         # LlamaForSequenceClassification
     ]

     for model_fn in model_list:
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward(org_model, sharded_model)
-        check_backward(org_model, sharded_model)
+        check_forward_backward(org_model, sharded_model)
         torch.cuda.empty_cache()


 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_llama():
     spawn(check_llama, 4)
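The gradient check gathers the column-parallel shards from every rank before comparing against the unsharded model. The same pattern in isolation, as a hedged sketch (assumes the default process group spans the tensor-parallel ranks):

import torch
import torch.distributed as dist


def gather_col_parallel_grad(shard_grad: torch.Tensor, world_size: int) -> torch.Tensor:
    # Each rank holds a slice of the weight gradient along dim 0; gather all
    # slices and concatenate them so the result can be compared with
    # torch.allclose against the original model's gradient.
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(shard_grad_list, shard_grad)
    return torch.cat(shard_grad_list, dim=0)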

View File

@@ -1,71 +1,72 @@
 import copy
 import os
-import random

 import pytest
 import torch
-from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer
+from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast

 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer.shard import ShardConfig, ShardFormer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

 CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
 tokenizer = T5Tokenizer.from_pretrained("t5-small")


-def build_model(rank, world_size):
-    config = T5Config.from_pretrained("t5-small")
+def build_model(world_size, model_fn):
+    config = T5Config(decoder_start_token_id=0)
     config.dropout_rate = 0
-    org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda')
-
-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-
-    org_model_for_shard = copy.deepcopy(org_model)
-    sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda')
+    org_model = model_fn(config=config).to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=world_size)
+
+    # shard model
+    shard_config = ShardConfig(tensor_parallel_size=world_size)
+    model_copy = copy.deepcopy(org_model)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(model_copy)

     return org_model, sharded_model


-def check_forward(org_model, sharded_model):
-    input_ids = tokenizer("translate English to German: The house is wonderful.",
-                          return_tensors="pt").input_ids.to('cuda')
-
-    #orgin model
-    org_model.eval()
-    org_output = org_model.generate(input_ids)
-
-    #shard model
-    sharded_model.eval()
-    shard_output = sharded_model.generate(input_ids)
-
-    assert torch.allclose(
-        org_output[0], shard_output[0],
-        atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
-
-
-def check_backward(org_model, sharded_model):
+def check_forward_backward(org_model, sharded_model):
     # prepare input
     input_ids = tokenizer("translate English to German: The house is wonderful.",
                           return_tensors="pt").input_ids.to('cuda')
     labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')

-    #orgin model
+    # switch to train mode
     org_model.train()
-    org_loss = org_model(input_ids=input_ids, labels=labels).loss
-    org_loss.backward()
-    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
-
-    #shard model
     sharded_model.train()
-    shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss
+
+    if isinstance(org_model, T5ForConditionalGeneration):
+        org_output = org_model(input_ids=input_ids, labels=labels)
+        org_loss = org_output.loss
+        shard_output = sharded_model(input_ids=input_ids, labels=labels)
+        shard_loss = shard_output.loss
+    elif isinstance(org_model, T5Model):
+        decoder_input_ids = org_model._shift_right(input_ids)
+        org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        org_loss = org_output.last_hidden_state.mean()
+        shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        shard_loss = shard_output.last_hidden_state.mean()
+    elif isinstance(org_model, T5EncoderModel):
+        org_output = org_model(input_ids=input_ids)
+        org_loss = org_output.last_hidden_state.mean()
+        shard_output = sharded_model(input_ids=input_ids)
+        shard_loss = shard_output.last_hidden_state.mean()
+
+    # key is sharded, so we ignore
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
+
+    # do backward
+    org_loss.backward()
     shard_loss.backward()
+
+    # check grad equality
+    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad

     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
@@ -82,16 +83,21 @@ def check_t5(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    org_model, sharded_model = build_model(rank, world_size)
-    check_forward(org_model, sharded_model)
-    check_backward(org_model, sharded_model)
-
-    torch.cuda.empty_cache()
+    model_fn_list = [
+        T5Model,
+        T5ForConditionalGeneration,
+        T5EncoderModel,
+    ]
+
+    for model_fn in model_fn_list:
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward_backward(org_model, sharded_model)
+        torch.cuda.empty_cache()


 @pytest.mark.dist
+@pytest.mark.skip
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_t5():
     spawn(check_t5, 2)