From 108e54a0b46bf7b103110842c303ebc26318efff Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 14 Aug 2023 15:49:13 +0800
Subject: [PATCH] [shardformer] update t5 tests to use all optimizations (#4407)

* [shardformer] gpt2 tests fix

[shardformer] test all optimizations (#4399)

* [shardformer] update t5 to use all optimizations
---
 colossalai/shardformer/README.md        |  2 +-
 tests/kit/model_zoo/transformers/t5.py  |  8 ++--
 .../test_model/test_shard_t5.py         | 39 +++++++++++++------
 3 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 5d00e606d..7dc15f0a0 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -30,7 +30,7 @@

 ### Quick Start

-The sample API usage is given below(If you enable the use of flash attention, please install xformers.):
+The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization, which requires the sequence length to be a multiple of 8):

 ```python
 from colossalai.shardformer import ShardConfig, Shard

diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py
index 435cb6f46..175d48963 100644
--- a/tests/kit/model_zoo/transformers/t5.py
+++ b/tests/kit/model_zoo/transformers/t5.py
@@ -16,8 +16,8 @@ def data_gen_for_encoder_only():
     # config = T5Config(decoder_start_token_id=0)
     # tokenizer = T5Tokenizer.from_pretrained("t5-small")
     # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
-    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
-    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12, 1627, 5, 1, 12]]).long()
+    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
     return dict(input_ids=input_ids, attention_mask=attention_mask)

@@ -26,7 +26,7 @@ def data_gen_for_conditional_generation():
     #
     # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
     data = data_gen_for_encoder_only()
-    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
     data['labels'] = labels
     return data

@@ -35,7 +35,7 @@ def data_gen_for_t5_model():
     # decoder_inputs_ids is obtained with the following code
     # decoder_input_ids = model._shift_right(input_ids)
     data = data_gen_for_encoder_only()
-    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
     data['decoder_input_ids'] = decoder_input_ids
     return data

diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index d807ffa06..fb065b422 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
         if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

-        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

     # unwrap model
     t5 = org_model
@@ -50,14 +54,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']

     # check weights and gradients
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-5, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
+        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)

     # check weights after optimizer.step()
     org_optimizer.step()
     sharded_optimizer.step()
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-4, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+        check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)

     torch.cuda.empty_cache()

@@ -66,23 +78,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'tp_size': 2,
     'pp_size': 2,
     'num_microbatches': 2,
-    'enable_fused_normalization': True,
-    'use_lazy_init': True
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
-    'use_lazy_init': False
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 4,
     'pp_size': 1,
-    'enable_fused_normalization': True,
-    'use_lazy_init': False
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32',
 }, {
     'tp_size': 1,
     'pp_size': 4,
     'num_microbatches': 4,
-    'use_lazy_init': False
+    'use_lazy_init': False,
+    'precision': 'fp32',
 }])
 @clear_cache_before_run()
 def run_t5_test(test_config):
@@ -93,7 +111,6 @@ def run_t5_test(test_config):
     # TODO: add test_config for flash attention & jit operator after supporting

     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-    test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
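Background for review: all three hunks in `test_shard_t5.py` repeat the same precision-aware tolerance selection. Reduced-precision (fp16) runs accumulate more rounding error than fp32 and therefore need looser `atol`/`rtol`, and the fp32 weight check after `optimizer.step()` also relaxes `atol` (1e-4 vs. 1e-5), presumably because the optimizer update compounds any gradient error. The sketch below distills that pattern; `get_tolerance` is a hypothetical helper for illustration only, not part of this patch or the ColossalAI test suite:

```python
import torch


def get_tolerance(precision: str, after_optim_step: bool = False):
    """Return (atol, rtol) mirroring the values hard-coded in the patch."""
    if precision == 'fp32':
        # weight checks after optimizer.step() tolerate a larger absolute error
        return (1e-4, 1e-3) if after_optim_step else (1e-5, 1e-3)
    # fp16 shares one loose tolerance for every check
    return (5e-3, 5e-3)


# usage: compare an "original" and a "sharded" tensor under fp16 tolerances
org = torch.randn(4, 8)
sharded = org + 1e-4    # simulate small numerical drift between the two runs
atol, rtol = get_tolerance('fp16')
assert torch.allclose(org, sharded, atol=atol, rtol=rtol)
```

Centralizing the thresholds in one helper like this would avoid repeating the `if/else` three times in the test body, at the cost of moving the concrete values away from the individual `check_*` calls.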