mirror of https://github.com/hpcaitech/ColossalAI
[shardformer]update t5 tests for using all optimizations. (#4407)
* [shardformer] gpt2 tests fix [shardformer] test all optimizations (#4399) [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] gpt2 tests fix * [shardformer]update t5 to use all optimizationspull/4445/head
parent
1edc9b5fb3
commit
108e54a0b4
|
@ -30,7 +30,7 @@
|
|||
|
||||
### Quick Start
|
||||
|
||||
The sample API usage is given below(If you enable the use of flash attention, please install xformers.):
|
||||
The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.):
|
||||
|
||||
```python
|
||||
from colossalai.shardformer import ShardConfig, Shard
|
||||
|
|
|
@ -16,8 +16,8 @@ def data_gen_for_encoder_only():
|
|||
# config = T5Config(decoder_start_token_id=0)
|
||||
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
|
||||
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
|
||||
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
|
||||
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12, 1627, 5, 1, 12]]).long()
|
||||
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ def data_gen_for_conditional_generation():
|
|||
#
|
||||
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
|
||||
data = data_gen_for_encoder_only()
|
||||
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
|
||||
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
|
||||
data['labels'] = labels
|
||||
return data
|
||||
|
||||
|
@ -35,7 +35,7 @@ def data_gen_for_t5_model():
|
|||
# decoder_inputs_ids is obtained with the following code
|
||||
# decoder_input_ids = model._shift_right(input_ids)
|
||||
data = data_gen_for_encoder_only()
|
||||
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
|
||||
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
|
||||
data['decoder_input_ids'] = decoder_input_ids
|
||||
return data
|
||||
|
||||
|
|
|
@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
t5 = org_model
|
||||
|
@ -50,14 +54,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
|
||||
|
||||
# check weights and gradients
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
|
||||
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
|
||||
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -66,23 +78,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 2,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': True
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1,
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'use_lazy_init': False
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1,
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 4,
|
||||
'num_microbatches': 4,
|
||||
'use_lazy_init': False
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}])
|
||||
@clear_cache_before_run()
|
||||
def run_t5_test(test_config):
|
||||
|
@ -93,7 +111,6 @@ def run_t5_test(test_config):
|
|||
# TODO: add test_config for flash attention & jit operator after supporting
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
||||
|
|
Loading…
Reference in New Issue