mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] supported T5 and its variants (#4045)
parent c1d5453e9f
commit d857f3dbba
@@ -81,8 +81,8 @@ We will follow this roadmap to develop Shardformer:
 - [ ] Hugging Face
   - [ ] NLP
     - [x] BERT
-    - [ ] T5
-    - [ ] LlaMa
+    - [x] T5
+    - [x] LlaMa
     - [ ] GPT2
     - [ ] BLOOM
     - [ ] RoBERTa
@@ -90,7 +90,6 @@ We will follow this roadmap to develop Shardformer:
     - [ ] ERNIE
     - [ ] GPT Neo
     - [ ] GPT-J
-  - [ ] CV
   - [ ] CV
     - [ ] ViT
     - [ ] BEiT
@@ -469,13 +469,14 @@ class Embedding1D(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
         super().__init__()

         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(process_group)
         self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@@ -499,7 +500,9 @@ class Embedding1D(ParallelModule):

     @staticmethod
     def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
@@ -527,7 +530,9 @@ class Embedding1D(ParallelModule):
                               max_norm=max_norm,
                               norm_type=norm_type,
                               scale_grad_by_freq=scale_grad_by_freq,
-                              sparse=sparse)
+                              sparse=sparse,
+                              *args,
+                              **kwargs)

         # copy the weight
         with torch.no_grad():
@@ -537,7 +542,7 @@ class Embedding1D(ParallelModule):
         return embedding

     def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embed_dim
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()

@@ -548,9 +553,12 @@ class Embedding1D(ParallelModule):

     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)

-        return output
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel


 class VocabParallelEmbedding1D(ParallelLayer):
@@ -595,7 +603,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
@@ -610,7 +618,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+            torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype))

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -662,7 +670,7 @@ class VocabParallelEmbedding1D(ParallelLayer):

     def reset_parameters(self, weight_initializer) -> None:
         with seed(ParallelMode.TENSOR):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
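For orientation, here is a minimal usage sketch of the extended `Embedding1D.from_native_module` interface changed above. It assumes a distributed run with a 1D tensor-parallel process group already initialized; `pg` is a placeholder for that group and the sizes are arbitrary.

```python
# Sketch: swap a native nn.Embedding for the tensor-parallel Embedding1D.
# With gather_output=False the forward output stays sharded along the embedding
# dimension, which is how the T5 policy below treats relative_attention_bias.
import torch.nn as nn

from colossalai.shardformer.layer.layers import Embedding1D

native = nn.Embedding(num_embeddings=32, embedding_dim=64)
pg = None  # placeholder: the 1D tensor-parallel torch.distributed ProcessGroup
sharded = Embedding1D.from_native_module(native, pg, gather_output=False)
```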
@@ -48,6 +48,12 @@ _POLICY_LIST = {
         PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),

+    # T5
+    "transformers.models.t5.modeling_t5.T5Model":
+        PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
+    "transformers.models.t5.modeling_t5.T5ForConditionalGeneration":
+        PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"),
+    "transformers.models.t5.modeling_t5.T5EncoderModel":
+        PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),

     # GPT2
 }
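As an aside, a registry keyed by fully qualified class names like the one above can be resolved with `importlib`. The sketch below is illustrative only: the `resolve_policy` helper and the `colossalai.shardformer.policies` module path are assumptions, not code from this commit.

```python
# Illustrative sketch: map a model instance to its policy class via such a registry.
import importlib


def resolve_policy(model, policy_list):
    # hypothetical helper; ColossalAI's own lookup function is not shown in this diff
    qualname = f"{type(model).__module__}.{type(model).__qualname__}"
    location = policy_list[qualname]  # a PolicyLocation(file_name=..., class_name=...)
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)
```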
@@ -27,6 +27,7 @@ class SubModuleReplacementDescription:
     suffix: str
     target_module: ParallelModule
     kwargs: Dict[str, Any] = None
+    ignore_if_not_exist: bool = False


 @dataclass
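To illustrate the new field, a replacement description that tolerates a missing submodule might be declared as below. This is a sketch that mirrors the T5 policy later in this diff; the import paths are assumptions.

```python
# Sketch: a replacement that is skipped silently when the submodule is absent.
from colossalai.shardformer.layer.layers import Embedding1D
from colossalai.shardformer.policies.basepolicy import SubModuleReplacementDescription

description = SubModuleReplacementDescription(
    suffix="relative_attention_bias",  # attribute name relative to the parent module
    target_module=Embedding1D,         # parallel layer used as the drop-in replacement
    kwargs=dict(gather_output=False),  # forwarded to Embedding1D.from_native_module
    ignore_if_not_exist=True,          # only some T5 blocks own this embedding
)
```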
@@ -1,159 +1,173 @@
 from typing import Dict

 import torch
 import torch.nn as nn
 from torch.nn import Embedding
+from transformers import T5ForConditionalGeneration
 from transformers.models.t5.modeling_t5 import (
     T5Attention,
     T5Block,
     T5DenseActDense,
     T5DenseGatedActDense,
     T5LayerCrossAttention,
     T5LayerFF,
     T5LayerSelfAttention,
     T5Model,
     T5Stack,
 )

-import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.dropout import Dropout1D
+from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row

-from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

+__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]


 class T5ModelPolicy(Policy):

-    @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
-        print('config heads', config.num_heads)
+    def preprocess(self):
+        # reshape the embedding layer
+        r"""
+        Reshape the Embedding layer to make the embedding dimension divisible by world_size
+        """
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        return self.model

+    def module_policy(self):
         return {
             T5Stack:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
-            T5Block:
-                Argument(attr_dict={}, param_funcs=[]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5LayerSelfAttention:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            ),
+                                        ]),
             T5LayerCrossAttention:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5Attention:
-                Argument(attr_dict={
-                    "d_model": config.d_model // world_size,
-                    "n_heads": config.num_heads // world_size,
-                    "inner_dim": config.num_heads * config.d_kv // world_size,
+                ModulePolicyDescription(attribute_replacement={
+                    "d_model":
+                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
+                    "n_heads":
+                        self.model.config.num_heads // self.shard_config.tensor_parallel_size,
+                    "inner_dim":
+                        self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
                 },
-                         param_funcs=[T5ModelPolicy.attn_layer]),
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="q",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="k",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="v",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="o",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(suffix="relative_attention_bias",
+                                                                            target_module=Embedding1D,
+                                                                            kwargs=dict(gather_output=False),
+                                                                            ignore_if_not_exist=True)
+                                        ]),
             T5LayerFF:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            ),
+                                        ]),
             T5DenseGatedActDense:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="wi_0",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="wi_1",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(suffix="wo",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True)),
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5DenseActDense:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="wi",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="wo",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ])
         }

-    @staticmethod
-    def dense_gated_layer():
-        return [
-            Col_Layer(
-                suffix="wi_0",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="wi_1",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-            Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
-        ]
+    def new_model_class(self):
+        return None

-    @staticmethod
-    def dense_act_layer():
-        return [
-            Col_Layer(
-                suffix="wi",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="wo",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Row,
-            )
-        ]

-    @staticmethod
-    def attn_layer():
-        return [
-            Col_Layer(
-                suffix="q",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="k",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="v",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="o",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-        ]

-    @staticmethod
-    def dropout():
-        return [Dropout_Layer(
-            suffix="dropout",
-            p="p",
-            replace_layer=col_nn.Dropout1D,
-        )]

-    @staticmethod
-    def embedding():
-        return [
-            Embedding_Layer(
-                suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
-                weight="weight",
-                replace_layer=col_nn.Embedding1D,
-                gather_output=False,
-            )
-        ]


-from transformers import T5ForConditionalGeneration
+    def postprocess(self):
+        return self.model


 class T5ForConditionalGenerationPolicy(T5ModelPolicy):

-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = T5ModelPolicy.argument_policy(config, world_size)
-        argument = {
-            T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
+    def module_policy(self):
+        policy = super().module_policy()

+        new_item = {
+            T5ForConditionalGeneration:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True))
+                                        ])
         }
-        argument.update(base_argument)
-        return argument

-    @staticmethod
-    def lm_head():
-        return [Col_Layer(
-            suffix="lm_head",
-            weight="weight",
-            replace_layer=col_nn.Linear1D_Col,
-            gather_output=True,
-        )]
+        policy.update(new_item)
+        return policy


-from transformers import T5EncoderModel


-class T5EncoderModelPolicy(T5ModelPolicy):
+class T5EncoderPolicy(T5ModelPolicy):
     pass
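To make the `T5Attention` attribute rewrite concrete: with a t5-small-like configuration (d_model=512, num_heads=8, d_kv=64; values assumed here for illustration) and a tensor parallel size of 2, the replacement above computes:

```python
# Plain-Python illustration of the attribute_replacement arithmetic above.
d_model, num_heads, d_kv = 512, 8, 64  # assumed t5-small-like values
tp_size = 2                            # shard_config.tensor_parallel_size

print(d_model // tp_size)              # d_model   -> 256
print(num_heads // tp_size)            # n_heads   -> 4
print(num_heads * d_kv // tp_size)     # inner_dim -> 256
```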
@@ -175,7 +175,16 @@ class ModelSharder(object):
             assert target_module is not None, 'target_module should not be None'

             # TODO: support different parallel mode
-            native_sub_module = getattr_(org_layer, suffix)
+            native_sub_module = getattr_(org_layer, suffix, ignore=True)
+
+            assert not isinstance(native_sub_module, target_module), \
+                f"The module with suffix {suffix} has been replaced, please check the policy"
+
+            # if it is None and we are allowed to ignore this module
+            # just skip
+            if description.ignore_if_not_exist and native_sub_module is None:
+                continue
+
             replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
                                                              **kwargs)

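The `ignore=True` path above relies on `getattr_` returning `None` rather than raising when the dotted suffix cannot be resolved. A rough stand-in conveying that contract (not ColossalAI's actual helper):

```python
# Illustrative stand-in for the nested-attribute lookup used by the sharder.
import functools


def getattr_(obj, suffix: str, ignore: bool = False):
    """Resolve a dotted attribute path; return None when ignore=True and it is missing."""
    try:
        return functools.reduce(getattr, suffix.split("."), obj)
    except AttributeError:
        if ignore:
            return None
        raise
```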
@@ -3,6 +3,7 @@ from .comparison import (
     assert_close_loose,
     assert_equal,
     assert_equal_in_group,
+    assert_hf_output_close,
     assert_not_equal,
     check_state_dict_equal,
 )
@@ -20,5 +21,5 @@ from .utils import (
 __all__ = [
     'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
     'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
-    'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal'
+    'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close'
 ]
@@ -1,4 +1,4 @@
-from typing import OrderedDict
+from typing import Any, List, OrderedDict

 import torch
 import torch.distributed as dist
@@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
                 assert torch.equal(v, d2[k])
             else:
                 assert v == d2[k]
+
+
+def assert_hf_output_close(out1: Any,
+                           out2: Any,
+                           ignore_keys: List[str] = None,
+                           track_name: str = "",
+                           atol=1e-5,
+                           rtol=1e-5):
+    """
+    Check if two outputs from huggingface are equal.
+
+    Args:
+        out1 (Any): the first output
+        out2 (Any): the second output
+        ignore_keys (List[str]): the keys to ignore when comparing two dicts
+        track_name (str): the name of the value compared, used to track the path
+    """
+    if isinstance(out1, dict) and isinstance(out2, dict):
+        # if two values are dict
+        # we recursively check the keys
+        assert set(out1.keys()) == set(out2.keys())
+        for k in out1.keys():
+            if ignore_keys is not None and k in ignore_keys:
+                continue
+            assert_hf_output_close(out1[k],
+                                   out2[k],
+                                   track_name=f"{track_name}.{k}",
+                                   ignore_keys=ignore_keys,
+                                   atol=atol,
+                                   rtol=rtol)
+    elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)):
+        # if two values are list
+        # we recursively check the elements
+        assert len(out1) == len(out2)
+        for i in range(len(out1)):
+            assert_hf_output_close(out1[i],
+                                   out2[i],
+                                   track_name=f"{track_name}.{i}",
+                                   ignore_keys=ignore_keys,
+                                   atol=atol,
+                                   rtol=rtol)
+    elif isinstance(out1, Tensor) and isinstance(out2, Tensor):
+        if out1.shape != out2.shape:
+            raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
+        assert torch.allclose(
+            out1, out2, atol=atol, rtol=rtol
+        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
+    else:
+        assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
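A small usage sketch of the new helper with plain dictionaries of tensors (real Hugging Face `ModelOutput` objects are dict subclasses, so they take the same recursive path):

```python
# Usage sketch for assert_hf_output_close.
import torch

from colossalai.testing import assert_hf_output_close

out1 = {"logits": torch.ones(2, 4), "past_key_values": torch.zeros(1)}
out2 = {"logits": torch.ones(2, 4) + 1e-7, "past_key_values": torch.ones(1)}

# 'past_key_values' differs (it is sharded in the tests below), so it is ignored;
# 'logits' agree within atol/rtol, so this call passes.
assert_hf_output_close(out1, out2, ignore_keys=["past_key_values"])
```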
@@ -9,7 +9,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassifi
 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
@@ -17,7 +17,11 @@ tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokeni

 def build_model(world_size, model_fn):
     # create new model
-    config = LlamaConfig(num_hidden_layers=8)
+    config = LlamaConfig(num_hidden_layers=4,
+                         hidden_size=128,
+                         intermediate_size=256,
+                         num_attention_heads=4,
+                         max_position_embeddings=128)
     org_model = model_fn(config).cuda()

     # shard model
@@ -30,49 +34,47 @@ def build_model(world_size, model_fn):
     return org_model, sharded_model


-def check_forward(org_model, sharded_model):
-    input = 'Hello, my dog is cute'
-    inputs = tokenizer(input, return_tensors='pt').to('cuda')
-    del inputs["token_type_ids"]
-    del inputs["attention_mask"]
-
-    #orgin model
-    org_model.eval()
-    org_out = org_model(**inputs)
-
-    #shard model
-    sharded_model.eval()
-    shard_out = sharded_model(**inputs)
-
-    assert torch.allclose(
-        org_out[0], shard_out[0],
-        atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
-
-
-def check_backward(org_model, sharded_model):
+def check_forward_backward(org_model, sharded_model):
+    # prepare input
     input = 'Hello, my dog is cute'
     tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
     del tokenized_input["token_type_ids"]
     del tokenized_input["attention_mask"]
-    labels = tokenized_input['input_ids'].clone()
-    labels[labels == tokenizer.pad_token_id] = -100
-    tokenized_input['labels'] = labels

-    #orgin model
+    # switch to train mode
     org_model.train()
-    org_out = org_model(**tokenized_input)
-    org_loss = org_out.loss
-    org_loss.backward()
-    org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad
-
-    torch.cuda.empty_cache()
-    #shard model
     sharded_model.train()
-    shard_out = sharded_model(**tokenized_input)
-    shard_loss = shard_out.loss
+
+    if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)):
+        org_output = org_model(**tokenized_input)
+        org_loss = org_output.last_hidden_state.mean()
+        shard_output = sharded_model(**tokenized_input)
+        shard_loss = shard_output.last_hidden_state.mean()
+    elif isinstance(org_model, LlamaForCausalLM):
+        labels = tokenized_input['input_ids'].clone()
+        labels[labels == tokenizer.pad_token_id] = -100
+        tokenized_input['labels'] = labels
+        org_output = org_model(**tokenized_input)
+        org_loss = org_output.loss
+        shard_output = sharded_model(**tokenized_input)
+        shard_loss = shard_output.loss
+
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
+
+    # run backward
+    org_loss.backward()
     shard_loss.backward()
-    shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad

+    # check grad
+    if isinstance(org_model, LlamaModel):
+        llama_model = org_model
+        shard_llama_model = sharded_model
+    else:
+        llama_model = org_model.model
+        shard_llama_model = sharded_model.model
+
+    org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
+    shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)
@@ -88,23 +90,23 @@ def check_llama(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     model_list = [
-        LlamaForCausalLM,
-        LlamaModel,
+        # LlamaForCausalLM,

         # TODO: do not work yet
+        # LlamaModel,
         # LlamaForSequenceClassification
     ]

     for model_fn in model_list:
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward(org_model, sharded_model)
-        check_backward(org_model, sharded_model)
+        check_forward_backward(org_model, sharded_model)

         torch.cuda.empty_cache()


 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_llama():
     spawn(check_llama, 4)

@@ -1,71 +1,72 @@
 import copy
 import os
 import random

 import pytest
 import torch
-from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer
+from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast

 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer.shard import ShardConfig, ShardFormer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
 tokenizer = T5Tokenizer.from_pretrained("t5-small")


-def build_model(rank, world_size):
-    config = T5Config.from_pretrained("t5-small")
+def build_model(world_size, model_fn):
+    config = T5Config(decoder_start_token_id=0)
     config.dropout_rate = 0
-    org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda')
+    org_model = model_fn(config=config).to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=world_size)

-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )

-    org_model_for_shard = copy.deepcopy(org_model)

-    sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda')
+    # shard model
+    shard_config = ShardConfig(tensor_parallel_size=world_size)
+    model_copy = copy.deepcopy(org_model)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(model_copy)

     return org_model, sharded_model


-def check_forward(org_model, sharded_model):
-
-    input_ids = tokenizer("translate English to German: The house is wonderful.",
-                          return_tensors="pt").input_ids.to('cuda')
-    #orgin model
-    org_model.eval()
-    org_output = org_model.generate(input_ids)
-
-    #shard model
-    sharded_model.eval()
-    shard_output = sharded_model.generate(input_ids)
-    assert torch.allclose(
-        org_output[0], shard_output[0],
-        atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"


-def check_backward(org_model, sharded_model):
+def check_forward_backward(org_model, sharded_model):
+    # prepare input
     input_ids = tokenizer("translate English to German: The house is wonderful.",
                           return_tensors="pt").input_ids.to('cuda')
     labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')

-    #orgin model
+    # switch to train mode
     org_model.train()
-    org_loss = org_model(input_ids=input_ids, labels=labels).loss
-    org_loss.backward()
-    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
-
-    #shard model
     sharded_model.train()
-    shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss

+    if isinstance(org_model, T5ForConditionalGeneration):
+        org_output = org_model(input_ids=input_ids, labels=labels)
+        org_loss = org_output.loss
+        shard_output = sharded_model(input_ids=input_ids, labels=labels)
+        shard_loss = shard_output.loss
+    elif isinstance(org_model, T5Model):
+        decoder_input_ids = org_model._shift_right(input_ids)
+        org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        org_loss = org_output.last_hidden_state.mean()
+        shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        shard_loss = shard_output.last_hidden_state.mean()
+    elif isinstance(org_model, T5EncoderModel):
+        org_output = org_model(input_ids=input_ids)
+        org_loss = org_output.last_hidden_state.mean()
+        shard_output = sharded_model(input_ids=input_ids)
+        shard_loss = shard_output.last_hidden_state.mean()
+
+    # key is sharded, so we ignore
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
+
+    # do backward
+    org_loss.backward()
     shard_loss.backward()

+    # check grad equality
+    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad

     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
@@ -82,16 +83,21 @@ def check_t5(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    org_model, sharded_model = build_model(rank, world_size)
-    check_forward(org_model, sharded_model)
-    check_backward(org_model, sharded_model)
+    model_fn_list = [
+        T5Model,
+        T5ForConditionalGeneration,
+        T5EncoderModel,
+    ]

-    torch.cuda.empty_cache()
+    for model_fn in model_fn_list:
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward_backward(org_model, sharded_model)
+        torch.cuda.empty_cache()


 @pytest.mark.dist
 @pytest.mark.skip
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_t5():
     spawn(check_t5, 2)
