mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] update tests for all optimization (#4413)
[shardformer] update tests for all optimizationpull/4445/head
parent
7711bd524a
commit
1edc9b5fb3
|
@ -1048,9 +1048,12 @@ def get_bert_flash_attention_forward():
|
|||
final_attention_mask = final_attention_mask * scale + attention_mask
|
||||
else:
|
||||
final_attention_mask = attention_mask
|
||||
|
||||
if final_attention_mask is not None:
|
||||
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
|
||||
tgt_len = key_layer.size()[2]
|
||||
final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
|
||||
final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len,
|
||||
tgt_len).contiguous()
|
||||
|
||||
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
|
||||
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
|
||||
|
|
|
@ -69,21 +69,30 @@ def data_gen_for_mcq():
|
|||
# data['labels'] = torch.tensor([0], dtype=torch.int64)
|
||||
input_ids = torch.tensor([[[
|
||||
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
|
||||
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
|
||||
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442,
|
||||
1012, 102, 102
|
||||
],
|
||||
[
|
||||
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
|
||||
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
|
||||
2218, 1999, 1996, 2192, 1012, 102, 0, 0
|
||||
2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0
|
||||
]]])
|
||||
token_type_ids = torch.tensor(
|
||||
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
0]]])
|
||||
attention_mask = torch.tensor(
|
||||
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
0]]])
|
||||
token_type_ids = torch.tensor([[[
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1
|
||||
],
|
||||
[
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
|
||||
]]])
|
||||
attention_mask = torch.tensor([[[
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1
|
||||
],
|
||||
[
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
|
||||
]]])
|
||||
labels = torch.tensor([0], dtype=torch.int64)
|
||||
|
||||
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
|
||||
|
|
|
@ -36,10 +36,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
tp_group = booster.plugin.tp_group
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if org_model.__class__.__name__ == 'BertModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'BertModel':
|
||||
bert = org_model
|
||||
|
@ -51,17 +55,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
col_layer_for_check = ['encoder.layer[0].output.dense']
|
||||
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
|
||||
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
#check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
|
||||
#check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
|
||||
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
|
||||
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
|
||||
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 5e-3, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
|
||||
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -70,23 +82,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'use_lazy_init': True
|
||||
'use_lazy_init': True,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
'num_microbatches': 2,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1,
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}])
|
||||
def run_bert_test(test_config):
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
|
||||
test_config['precision'] = 'float'
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
|
Loading…
Reference in New Issue