[shardformer] supported T5 and its variants (#4045)

pull/4157/head
Frank Lee 2023-06-19 17:57:37 +08:00
parent c1d5453e9f
commit d857f3dbba
10 changed files with 316 additions and 221 deletions

View File

@ -81,8 +81,8 @@ We will follow this roadmap to develop Shardformer:
- [ ] Hugging Face
- [ ] NLP
- [x] BERT
- [ ] T5
- [ ] LLaMA
- [x] T5
- [x] LLaMA
- [ ] GPT2
- [ ] BLOOM
- [ ] RoBERTa
@ -90,7 +90,6 @@ We will follow this roadmap to develop Shardformer:
- [ ] ERNIE
- [ ] GPT Neo
- [ ] GPT-J
- [ ] CV
- [ ] CV
- [ ] ViT
- [ ] BEiT

View File

@ -469,13 +469,14 @@ class Embedding1D(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@ -499,7 +500,9 @@ class Embedding1D(ParallelModule):
@staticmethod
def from_native_module(module: nn.Embedding,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
*args,
**kwargs) -> "Embedding1D":
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
@ -527,7 +530,9 @@ class Embedding1D(ParallelModule):
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
sparse=sparse,
*args,
**kwargs)
# copy the weight
with torch.no_grad():
@ -537,7 +542,7 @@ class Embedding1D(ParallelModule):
return embedding
def reset_parameters(self, weight_initializer) -> None:
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
@ -548,9 +553,12 @@ class Embedding1D(ParallelModule):
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
else:
return output_parallel
class VocabParallelEmbedding1D(ParallelLayer):
@ -595,7 +603,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
@ -610,7 +618,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype))
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
@ -662,7 +670,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
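For context on the `gather_output` flag added to `Embedding1D` above: the layer shards the embedding weight along the embedding dimension, so each rank only produces a slice of the output features. With `gather_output=True` (the default) the slices are all-gathered along the last dimension; with `gather_output=False` the local slice is returned, which is what the T5 policy below relies on for `relative_attention_bias`. The following minimal sketch is illustrative only (the shapes and the world size of 2 are assumptions, not taken from the diff) and presumes an already-initialized 2-rank process group used as the tensor-parallel group.

import torch
import torch.nn as nn
from colossalai.shardformer.layer.layers import Embedding1D

# assumption: torch.distributed is initialized with 2 ranks acting as the tensor-parallel group
native = nn.Embedding(32, 8)    # vocab_size=32, embedding_dim=8
gathered = Embedding1D.from_native_module(native, process_group=None)    # gather_output defaults to True
partial = Embedding1D.from_native_module(native, process_group=None, gather_output=False)

tokens = torch.randint(0, 32, (2, 5))
print(gathered(tokens).shape)    # (2, 5, 8): full embedding dim after the all-gather
print(partial(tokens).shape)     # (2, 5, 4): local shard only, as used for relative_attention_bias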

View File

@ -48,6 +48,12 @@ _POLICY_LIST = {
PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),
# T5
"transformers.models.t5.modeling_t5.T5Model":
PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
"transformers.models.t5.modeling_t5.T5ForConditionalGeneration":
PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"),
"transformers.models.t5.modeling_t5.T5EncoderModel":
PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
# GPT2
}
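The registry above maps fully qualified transformer class names to policy locations so that a policy can be resolved automatically from a model instance. The sketch below shows roughly how such a lookup can work; the helper names (`_fullname`, `lookup_policy`) are illustrative assumptions, not the shardformer's actual API.

def _fullname(model) -> str:
    # e.g. "transformers.models.t5.modeling_t5.T5Model"
    klass = model.__class__
    return f"{klass.__module__}.{klass.__qualname__}"

def lookup_policy(model, policy_list):
    # resolve the registered PolicyLocation for this model class, if any
    location = policy_list.get(_fullname(model))
    if location is None:
        raise NotImplementedError(f"no auto policy registered for {_fullname(model)}")
    return location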

View File

@ -27,6 +27,7 @@ class SubModuleReplacementDescription:
suffix: str
target_module: ParallelModule
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False
@dataclass

View File

@ -1,159 +1,173 @@
from typing import Dict
import torch
import torch.nn as nn
from torch.nn import Embedding
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5Block,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Model,
T5Stack,
)
import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
class T5ModelPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
print('config heads', config.num_heads)
def preprocess(self):
r"""
Resize the token embeddings so that the vocabulary size is divisible by the tensor parallel world size.
"""
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
return {
T5Stack:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
T5Block:
Argument(attr_dict={}, param_funcs=[]),
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
)
]),
T5LayerSelfAttention:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
),
]),
T5LayerCrossAttention:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
)
]),
T5Attention:
Argument(attr_dict={
"d_model": config.d_model // world_size,
"n_heads": config.num_heads // world_size,
"inner_dim": config.num_heads * config.d_kv // world_size,
ModulePolicyDescription(attribute_replacement={
"d_model":
self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads":
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
param_funcs=[T5ModelPolicy.attn_layer]),
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True)
]),
T5LayerFF:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
),
]),
T5DenseGatedActDense:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(suffix="wo",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
)
]),
T5DenseActDense:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
)
])
}
@staticmethod
def dense_gated_layer():
return [
Col_Layer(
suffix="wi_0",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="wi_1",
weight="weight",
replace_layer=col_nn.Linear1D_Row,
),
Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
]
def new_model_class(self):
return None
@staticmethod
def dense_act_layer():
return [
Col_Layer(
suffix="wi",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="wo",
weight="weight",
replace_layer=col_nn.Linear1D_Row,
)
]
@staticmethod
def attn_layer():
return [
Col_Layer(
suffix="q",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="k",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="v",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="o",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
),
]
@staticmethod
def dropout():
return [Dropout_Layer(
suffix="dropout",
p="p",
replace_layer=col_nn.Dropout1D,
)]
@staticmethod
def embedding():
return [
Embedding_Layer(
suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
weight="weight",
replace_layer=col_nn.Embedding1D,
gather_output=False,
)
]
from transformers import T5ForConditionalGeneration
def postprocess(self):
return self.model
class T5ForConditionalGenerationPolicy(T5ModelPolicy):
@staticmethod
def argument_policy(config, world_size):
base_argument = T5ModelPolicy.argument_policy(config, world_size)
argument = {
T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
def module_policy(self):
policy = super().module_policy()
new_item = {
T5ForConditionalGeneration:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
argument.update(base_argument)
return argument
@staticmethod
def lm_head():
return [Col_Layer(
suffix="lm_head",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)]
policy.update(new_item)
return policy
from transformers import T5EncoderModel
class T5EncoderModelPolicy(T5ModelPolicy):
class T5EncoderPolicy(T5ModelPolicy):
pass
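A quick arithmetic check of the `T5Attention` attribute replacement above, using t5-small's configuration (d_model=512, num_heads=8, d_kv=64) and a tensor parallel size of 2 purely as illustrative numbers: once `q`, `k` and `v` are replaced with `Linear1D_Col`, each rank only computes the projections for its local heads, so the attention module's bookkeeping attributes must shrink by the same factor.

tp_size = 2
d_model, num_heads, d_kv = 512, 8, 64    # t5-small values, used only as an example

n_heads_per_rank = num_heads // tp_size              # 4 heads live on each rank
inner_dim_per_rank = num_heads * d_kv // tp_size     # 256 = out_features of the sharded q/k/v
d_model_per_rank = d_model // tp_size                # 256, mirrors the policy's d_model replacement

assert inner_dim_per_rank == n_heads_per_rank * d_kv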

View File

@ -175,7 +175,16 @@ class ModelSharder(object):
assert target_module is not None, 'target_module should not be None'
# TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix)
native_sub_module = getattr_(org_layer, suffix, ignore=True)
assert not isinstance(native_sub_module, target_module), \
f"The module with suffix {suffix} has been replaced, please check the policy"
# if it is None and we are allowed to ignore this module
# just skip
if description.ignore_if_not_exist and native_sub_module is None:
continue
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
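The change above lets the sharder skip a replacement when the target submodule is absent and the description sets `ignore_if_not_exist=True` (T5 needs this because `relative_attention_bias` only exists in the first block of each stack). Below is a simplified, self-contained sketch of that decision; `_getattr_by_suffix` and `maybe_replace` are stand-in names for illustration, not the real shardformer helpers.

def _getattr_by_suffix(module, suffix, ignore=False):
    # walk a dotted suffix such as "SelfAttention.relative_attention_bias"
    obj = module
    for name in suffix.split('.'):
        if not hasattr(obj, name):
            if ignore:
                return None
            raise AttributeError(f"submodule {suffix} not found")
        obj = getattr(obj, name)
    return obj

def maybe_replace(module, description, process_group):
    native = _getattr_by_suffix(module, description.suffix, ignore=True)
    if native is None and description.ignore_if_not_exist:
        return None    # nothing to replace on this module, silently skip
    kwargs = description.kwargs or {}
    return description.target_module.from_native_module(native, process_group, **kwargs)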

View File

@ -3,6 +3,7 @@ from .comparison import (
assert_close_loose,
assert_equal,
assert_equal_in_group,
assert_hf_output_close,
assert_not_equal,
check_state_dict_equal,
)
@ -20,5 +21,5 @@ from .utils import (
__all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal'
'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close'
]

View File

@ -1,4 +1,4 @@
from typing import OrderedDict
from typing import Any, List, OrderedDict
import torch
import torch.distributed as dist
@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
assert torch.equal(v, d2[k])
else:
assert v == d2[k]
def assert_hf_output_close(out1: Any,
out2: Any,
ignore_keys: List[str] = None,
track_name: str = "",
atol=1e-5,
rtol=1e-5):
"""
Check if two outputs from huggingface are equal.
Args:
out1 (Any): the first output
out2 (Any): the second output
ignore_keys (List[str]): the keys to ignore when comparing two dicts
track_name (str): the name of the value compared, used to track the path
"""
if isinstance(out1, dict) and isinstance(out2, dict):
# if two values are dict
# we recursively check the keys
assert set(out1.keys()) == set(out2.keys())
for k in out1.keys():
if ignore_keys is not None and k in ignore_keys:
continue
assert_hf_output_close(out1[k],
out2[k],
track_name=f"{track_name}.{k}",
ignore_keys=ignore_keys,
atol=atol,
rtol=rtol)
elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)):
# if two values are list
# we recursively check the elements
assert len(out1) == len(out2)
for i in range(len(out1)):
assert_hf_output_close(out1[i],
out2[i],
track_name=f"{track_name}.{i}",
ignore_keys=ignore_keys,
atol=atol,
rtol=rtol)
elif isinstance(out1, Tensor) and isinstance(out2, Tensor):
if out1.shape != out2.shape:
raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
assert torch.allclose(
out1, out2, atol=atol, rtol=rtol
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
else:
assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
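A minimal usage sketch of `assert_hf_output_close` (the tensors below are placeholders, not outputs from the tests in this commit): values under ignored keys are never compared, so even a shape mismatch there is tolerated, while the remaining tensors must agree within `atol`/`rtol`.

import torch
from colossalai.testing import assert_hf_output_close

out_a = {'logits': torch.ones(2, 4), 'past_key_values': (torch.zeros(1),)}
out_b = {'logits': torch.ones(2, 4) + 1e-7, 'past_key_values': (torch.zeros(2),)}

# 'past_key_values' differ in shape but are skipped; 'logits' differ by 1e-7, well within atol
assert_hf_output_close(out_a, out_b, ignore_keys=['past_key_values'], atol=1e-5, rtol=1e-5)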

View File

@ -9,7 +9,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassifi
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
@ -17,7 +17,11 @@ tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokeni
def build_model(world_size, model_fn):
# create new model
config = LlamaConfig(num_hidden_layers=8)
config = LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128)
org_model = model_fn(config).cuda()
# shard model
@ -30,49 +34,47 @@ def build_model(world_size, model_fn):
return org_model, sharded_model
def check_forward(org_model, sharded_model):
input = 'Hello, my dog is cute'
inputs = tokenizer(input, return_tensors='pt').to('cuda')
del inputs["token_type_ids"]
del inputs["attention_mask"]
#orgin model
org_model.eval()
org_out = org_model(**inputs)
#shard model
sharded_model.eval()
shard_out = sharded_model(**inputs)
assert torch.allclose(
org_out[0], shard_out[0],
atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
def check_backward(org_model, sharded_model):
def check_forward_backward(org_model, sharded_model):
# prepare input
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
del tokenized_input["token_type_ids"]
del tokenized_input["attention_mask"]
labels = tokenized_input['input_ids'].clone()
labels[labels == tokenizer.pad_token_id] = -100
tokenized_input['labels'] = labels
#orgin model
# switch to train mode
org_model.train()
org_out = org_model(**tokenized_input)
org_loss = org_out.loss
org_loss.backward()
org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad
torch.cuda.empty_cache()
#shard model
sharded_model.train()
shard_out = sharded_model(**tokenized_input)
shard_loss = shard_out.loss
if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)):
org_output = org_model(**tokenized_input)
org_loss = org_output.last_hidden_state.mean()
shard_output = sharded_model(**tokenized_input)
shard_loss = shard_output.last_hidden_state.mean()
elif isinstance(org_model, LlamaForCausalLM):
labels = tokenized_input['input_ids'].clone()
labels[labels == tokenizer.pad_token_id] = -100
tokenized_input['labels'] = labels
org_output = org_model(**tokenized_input)
org_loss = org_output.loss
shard_output = sharded_model(**tokenized_input)
shard_loss = shard_output.loss
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
# run backward
org_loss.backward()
shard_loss.backward()
shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad
# check grad
if isinstance(org_model, LlamaModel):
llama_model = org_model
shard_llama_model = sharded_model
else:
llama_model = org_model.model
shard_llama_model = sharded_model.model
org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
@ -88,23 +90,23 @@ def check_llama(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model_list = [
LlamaForCausalLM,
LlamaModel,
# LlamaForCausalLM,
# TODO: these do not work yet
# LlamaModel,
# LlamaForSequenceClassification
]
for model_fn in model_list:
org_model, sharded_model = build_model(world_size, model_fn)
check_forward(org_model, sharded_model)
check_backward(org_model, sharded_model)
check_forward_backward(org_model, sharded_model)
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 4)
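For the gradient check above: `q_proj` is replaced by a `Linear1D_Col`, which keeps a slice of the full weight's output rows on each rank, so gathering the per-rank gradients and concatenating them along dim 0 reconstructs the gradient of the unsharded weight. The single-process illustration below (no distributed launch, made-up sizes) shows why dim 0 is the right axis.

import torch

full_weight = torch.arange(24.).reshape(8, 3)              # (out_features=8, in_features=3)
rank_shards = torch.chunk(full_weight, chunks=4, dim=0)    # what each of 4 ranks would hold
assert torch.equal(torch.cat(rank_shards, dim=0), full_weight)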

View File

@ -1,71 +1,72 @@
import copy
import os
import random
import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = T5Tokenizer.from_pretrained("t5-small")
def build_model(rank, world_size):
config = T5Config.from_pretrained("t5-small")
def build_model(world_size, model_fn):
config = T5Config(decoder_start_token_id=0)
config.dropout_rate = 0
org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda')
org_model = model_fn(config=config).to('cuda')
shard_config = ShardConfig(tensor_parallel_size=world_size)
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
org_model_for_shard = copy.deepcopy(org_model)
sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda')
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy)
return org_model, sharded_model
def check_forward(org_model, sharded_model):
input_ids = tokenizer("translate English to German: The house is wonderful.",
return_tensors="pt").input_ids.to('cuda')
#orgin model
org_model.eval()
org_output = org_model.generate(input_ids)
#shard model
sharded_model.eval()
shard_output = sharded_model.generate(input_ids)
assert torch.allclose(
org_output[0], shard_output[0],
atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
def check_backward(org_model, sharded_model):
def check_forward_backward(org_model, sharded_model):
# prepare input
input_ids = tokenizer("translate English to German: The house is wonderful.",
return_tensors="pt").input_ids.to('cuda')
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')
#orgin model
# switch to train mode
org_model.train()
org_loss = org_model(input_ids=input_ids, labels=labels).loss
org_loss.backward()
org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
#shard model
sharded_model.train()
shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss
if isinstance(org_model, T5ForConditionalGeneration):
org_output = org_model(input_ids=input_ids, labels=labels)
org_loss = org_output.loss
shard_output = sharded_model(input_ids=input_ids, labels=labels)
shard_loss = shard_output.loss
elif isinstance(org_model, T5Model):
decoder_input_ids = org_model._shift_right(input_ids)
org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
org_loss = org_output.last_hidden_state.mean()
shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
shard_loss = shard_output.last_hidden_state.mean()
elif isinstance(org_model, T5EncoderModel):
org_output = org_model(input_ids=input_ids)
org_loss = org_output.last_hidden_state.mean()
shard_output = sharded_model(input_ids=input_ids)
shard_loss = shard_output.last_hidden_state.mean()
# the cached key/value tensors are sharded across ranks, so skip them in the comparison
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
# do backward
org_loss.backward()
shard_loss.backward()
# check grad equality
org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
@ -82,16 +83,21 @@ def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
org_model, sharded_model = build_model(rank, world_size)
check_forward(org_model, sharded_model)
check_backward(org_model, sharded_model)
model_fn_list = [
T5Model,
T5ForConditionalGeneration,
T5EncoderModel,
]
torch.cuda.empty_cache()
for model_fn in model_fn_list:
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model)
torch.cuda.empty_cache()
@pytest.mark.dist
@pytest.mark.skip
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5():
spawn(check_t5, 2)