Browse Source

[shardformer] adapted T5 and LLaMa test to use kit (#4049)

* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
pull/4157/head
Frank Lee 1 year ago
parent
commit
58df720570
  1. 22
      colossalai/shardformer/layer/embedding1d.py
  2. 2
      colossalai/shardformer/policies/llama.py
  3. 3
      colossalai/shardformer/policies/t5.py
  4. 11
      colossalai/shardformer/shard/sharder.py
  5. 2
      colossalai/testing/comparison.py
  6. 30
      tests/kit/model_zoo/registry.py
  7. 1
      tests/kit/model_zoo/transformers/__init__.py
  8. 76
      tests/kit/model_zoo/transformers/llama.py
  9. 53
      tests/kit/model_zoo/transformers/t5.py
  10. 2
      tests/test_booster/test_mixed_precision/test_fp16_torch.py
  11. 2
      tests/test_booster/test_plugin/test_gemini_plugin.py
  12. 2
      tests/test_booster/test_plugin/test_low_level_zero_plugin.py
  13. 2
      tests/test_booster/test_plugin/test_torch_ddp_plugin.py
  14. 2
      tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
  15. 4
      tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
  16. 2
      tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
  17. 2
      tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
  18. 2
      tests/test_lazy/lazy_init_utils.py
  19. 2
      tests/test_lazy/test_distribute.py
  20. 0
      tests/test_shardformer/__init__.py
  21. 0
      tests/test_shardformer/test_model/__init__.py
  22. 38
      tests/test_shardformer/test_model/_utils.py
  23. 78
      tests/test_shardformer/test_model/test_shard_llama.py
  24. 75
      tests/test_shardformer/test_model/test_shard_t5.py

22
colossalai/shardformer/layer/embedding1d.py

@ -65,13 +65,14 @@ class Embedding1D(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@ -79,7 +80,7 @@ class Embedding1D(ParallelModule):
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
# self.gather_output = gather_output
self.gather_output = gather_output
if device is None:
device = get_current_device()
@ -95,7 +96,9 @@ class Embedding1D(ParallelModule):
@staticmethod
def from_native_module(module: nn.Embedding,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
*args,
**kwargs) -> "Embedding1D":
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
@ -123,7 +126,9 @@ class Embedding1D(ParallelModule):
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
sparse=sparse,
*args,
**kwargs)
# copy the weight
with torch.no_grad():
@ -133,7 +138,7 @@ class Embedding1D(ParallelModule):
return embedding
def reset_parameters(self, weight_initializer) -> None:
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
@ -144,6 +149,9 @@ class Embedding1D(ParallelModule):
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
else:
return output_parallel

2
colossalai/shardformer/policies/llama.py

@ -4,7 +4,7 @@ import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

3
colossalai/shardformer/policies/t5.py

@ -11,8 +11,7 @@ from transformers.models.t5.modeling_t5 import (
T5Stack,
)
from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

11
colossalai/shardformer/shard/sharder.py

@ -185,7 +185,14 @@ class ModelSharder(object):
if description.ignore_if_not_exist and native_sub_module is None:
continue
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
try:
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
f" with {target_module.__qualname__} with the exception: {e}. "
"Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
)
setattr_(org_layer, suffix, replace_layer)

2
colossalai/testing/comparison.py

@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
assert torch.allclose(
out1, out2, atol=atol, rtol=rtol
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
else:
assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"

30
tests/kit/model_zoo/registry.py

@ -28,27 +28,35 @@ class ModelZooRegistry(dict):
model_fn: Callable,
data_gen_fn: Callable,
output_transform_fn: Callable,
loss_fn: Callable = None,
model_attribute: ModelAttribute = None):
"""
Register a model and data generation function.
Examples:
>>> # Register
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
>>> data = resnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)
```python
# normal forward workflow
model = resnet18()
data = resnet18_data_gen()
output = model(**data)
transformed_output = output_transform_fn(output)
loss = loss_fn(transformed_output)
# Register
model_zoo = ModelZooRegistry()
model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
```
Args:
name (str): Name of the model.
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
output_transform_fn (callable): A function that transforms the output of the model into Dict.
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
output_transform_fn (Callable): A function that transforms the output of the model into Dict.
loss_fn (Callable): a function to compute the loss from the given output. Defaults to None
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
def get_sub_registry(self, keyword: str):
"""

1
tests/kit/model_zoo/transformers/__init__.py

@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .gpt import *
from .llama import *
from .opt import *
from .t5 import *

76
tests/kit/model_zoo/transformers/llama.py

@ -0,0 +1,76 @@
import torch
import transformers
from ..registry import ModelAttribute, model_zoo
try:
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
HAS_LLAMA = True
except ImportError:
HAS_LLAMA = False
if HAS_LLAMA:
# ===============================
# Register LLaMA
# ===============================
def data_gen():
# the input ids are corresponding to the sentence
# 'Hello, my dog is cute'
#
# the code is give below:
# -----------------------------------
# from transformers import LlamaTokenizerFast
# tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------
input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm
def data_gen_for_casual_lm():
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
return data
# transform the output to a dict
output_transform_fn = lambda x: x
# function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean()
loss_fn_for_casual_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)
# register the following models
# transformers.LlamaModel,
# transformers.LlamaForCausalLM,
# transformers.LlamaForSequenceClassification,
model_zoo.register(name='transformers_llama',
model_fn=lambda: transformers.LlamaModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_casual_lm',
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_sequence_classification',
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True))

53
tests/kit/model_zoo/transformers/t5.py

@ -6,24 +6,50 @@ from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence T5
# ===============================
BATCH_SIZE = 2
SEQ_LENGTH = 16
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
# define data gen function
def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
# Generated from following code snippet
#
# from transformers import T5Config, T5Tokenizer
# config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
return dict(input_ids=input_ids)
def data_gen_for_conditional_generation():
# labels is generated with the following code
#
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only()
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
data['labels'] = labels
return data
def data_gen_for_t5_model():
# decoder_inputs_ids is obtained with the following code
#
# decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only()
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids
return data
# output transform function
output_transform_fn = lambda x: x
config = transformers.T5Config(d_model=128, num_layers=2)
# define loss funciton
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss
# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
# register the following models
# transformers.T5Model,
@ -31,16 +57,19 @@ config = transformers.T5Config(d_model=128, num_layers=2)
# transformers.T5EncoderModel,
model_zoo.register(name='transformers_t5',
model_fn=lambda: transformers.T5Model(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_t5_model,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_t5_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_for_conditional_generation',
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_conditional_generation,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_encoder_model',
model_fn=lambda: transformers.T5EncoderModel(config),
data_gen_fn=data_gen_for_encoder_only,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True))

2
tests/test_booster/test_mixed_precision/test_fp16_torch.py

@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items():
# dlrm_interactionarch has not parameters, so skip
if name == 'dlrm_interactionarch':
continue

2
tests/test_booster/test_plugin/test_gemini_plugin.py

@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):

2
tests/test_booster/test_plugin/test_low_level_zero_plugin.py

@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
skipped_models = []
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# FIXME(ver217): fix these models
if name in ignore_models:
skipped_models.append(name)

2
tests/test_booster/test_plugin/test_torch_ddp_plugin.py

@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
def check_torch_ddp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if name == 'dlrm_interactionarch':
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)

2
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py

@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if any(element in name for element in [
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
'torchvision_inception_v3'

4
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py

@ -47,7 +47,7 @@ def test_diffusers():
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
trace_and_compare(model_fn, data, output_transform_fn)
torch.cuda.synchronize()
@ -60,7 +60,7 @@ def test_torch_diffusers():
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
model = model_fn()
output = model(**data)

2
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py

@ -56,7 +56,7 @@ def test_timm_models():
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}

2
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py

@ -16,7 +16,7 @@ def test_torchaudio_models():
sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
model = model_fn()
trace_and_compare(model,
data_gen_fn,

2
tests/test_lazy/lazy_init_utils.py

@ -60,7 +60,7 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn:
def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
_MyTensor._pre_op_fn = lambda *args: set_seed(seed)
LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
ctx = LazyInitContext(tensor_cls=_MyTensor)

2
tests/test_lazy/test_distribute.py

@ -78,7 +78,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
continue
print_rank_0(name)
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
ctx = LazyInitContext(tensor_cls=_MyTensor)
with ctx:
model = model_fn()

0
tests/test_shardformer/__init__.py

0
tests/test_shardformer/test_model/__init__.py

38
tests/test_shardformer/test_model/_utils.py

@ -0,0 +1,38 @@
import copy
from colossalai.shardformer import ShardConfig, ShardFormer
def build_model(world_size, model_fn):
# create new model
org_model = model_fn().cuda()
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy)
return org_model, sharded_model
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# switch to train mode
original_model.train()
sharded_model.train()
# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)
org_loss = loss_fn(org_output)
shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
shard_loss = loss_fn(shard_output)
return org_output, org_loss, shard_output, shard_loss

78
tests/test_shardformer/test_model/test_shard_llama.py

@ -1,64 +1,22 @@
import copy
import os
import random
import pytest
import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
def build_model(world_size, model_fn):
# create new model
config = LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128)
org_model = model_fn(config).cuda()
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy)
return org_model, sharded_model
def check_forward_backward(org_model, sharded_model):
# prepare input
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
del tokenized_input["token_type_ids"]
del tokenized_input["attention_mask"]
# switch to train mode
org_model.train()
sharded_model.train()
if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)):
org_output = org_model(**tokenized_input)
org_loss = org_output.last_hidden_state.mean()
shard_output = sharded_model(**tokenized_input)
shard_loss = shard_output.last_hidden_state.mean()
elif isinstance(org_model, LlamaForCausalLM):
labels = tokenized_input['input_ids'].clone()
labels[labels == tokenizer.pad_token_id] = -100
tokenized_input['labels'] = labels
org_output = org_model(**tokenized_input)
org_loss = org_output.loss
shard_output = sharded_model(**tokenized_input)
shard_loss = shard_output.loss
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
# forward check
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
# run backward
@ -66,12 +24,12 @@ def check_forward_backward(org_model, sharded_model):
shard_loss.backward()
# check grad
if isinstance(org_model, LlamaModel):
llama_model = org_model
shard_llama_model = sharded_model
else:
if hasattr(org_model, 'model'):
llama_model = org_model.model
shard_llama_model = sharded_model.model
else:
llama_model = org_model
shard_llama_model = sharded_model
org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
@ -89,17 +47,11 @@ def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model_list = [
LlamaModel,
# LlamaForCausalLM,
# TODO: do not work yet
# LlamaForSequenceClassification
]
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for model_fn in model_list:
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

75
tests/test_shardformer/test_model/test_shard_t5.py

@ -1,64 +1,20 @@
import copy
import os
import pytest
import torch
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, ShardFormer
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = T5Tokenizer.from_pretrained("t5-small")
def build_model(world_size, model_fn):
config = T5Config(decoder_start_token_id=0)
config.dropout_rate = 0
org_model = model_fn(config=config).to('cuda')
shard_config = ShardConfig(tensor_parallel_size=world_size)
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy)
return org_model, sharded_model
def check_forward_backward(org_model, sharded_model):
# prepare input
input_ids = tokenizer("translate English to German: The house is wonderful.",
return_tensors="pt").input_ids.to('cuda')
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')
# switch to train mode
org_model.train()
sharded_model.train()
if isinstance(org_model, T5ForConditionalGeneration):
org_output = org_model(input_ids=input_ids, labels=labels)
org_loss = org_output.loss
shard_output = sharded_model(input_ids=input_ids, labels=labels)
shard_loss = shard_output.loss
elif isinstance(org_model, T5Model):
decoder_input_ids = org_model._shift_right(input_ids)
org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
org_loss = org_output.last_hidden_state.mean()
shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
shard_loss = shard_output.last_hidden_state.mean()
elif isinstance(org_model, T5EncoderModel):
org_output = org_model(input_ids=input_ids)
org_loss = org_output.last_hidden_state.mean()
shard_output = sharded_model(input_ids=input_ids)
shard_loss = shard_output.last_hidden_state.mean()
# key is sharded, so we ignore
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
# the value "past_key_values" is sharded, so we ignore
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
# do backward
@ -81,18 +37,15 @@ def check_forward_backward(org_model, sharded_model):
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model_fn_list = [
T5Model,
T5ForConditionalGeneration,
T5EncoderModel,
]
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for model_fn in model_fn_list:
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model)
torch.cuda.empty_cache()
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
@pytest.mark.dist

Loading…
Cancel
Save