mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] fix, test gpt2 for AMP+TP (#4403)

* [shardformer] gpt2 tests fix
* [shardformer] test all optimizations (#4399)

branch: pull/4445/head
parent 7596e9ae08
commit 21e0a42fd1
@@ -210,7 +210,7 @@ def check_weight(org_model: Module,
     if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
         sharded_weight_list = [
-            torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+            torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
         ]
         dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
         sharded_weight = torch.cat(sharded_weight_list, dim=dim)
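Why this one-word change matters: torch.zeros([*shape]) allocates in the default dtype (fp32 unless changed), while torch.zeros_like(sharded_weight) matches the shard's dtype, so under AMP the fp16 shards are gathered into fp16 buffers and all_gather no longer sees a dtype mismatch. A minimal runnable sketch of the gather pattern, using a single-process gloo group and an illustrative 4x2 shard (assumptions, not the test's setup; the real test runs one rank per tensor-parallel shard on CUDA):

    import os
    import torch
    import torch.distributed as dist

    # single-process group so the sketch runs standalone
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    tp_group = dist.group.WORLD

    shard = torch.randn(4, 2).half()    # fp16 shard, as under AMP

    # zeros_like keeps dtype (and device); torch.zeros(shape) would hand
    # fp32 buffers to an fp16 all_gather
    buffers = [torch.zeros_like(shard) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(buffers, shard, group=tp_group)
    full = torch.cat(buffers, dim=0)    # dim = axis along which the weight was split
    print(full.dtype)                   # torch.float16

    dist.destroy_process_group()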
@@ -219,7 +219,7 @@ def check_weight(org_model: Module,
         print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

     assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
-        f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+        f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"


 def check_grad(org_model: Module,
@@ -236,9 +236,7 @@ def check_grad(org_model: Module,
     shard_weight = getattr_(sharded_model, suffix).weight

     if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [
-            torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
-        ]
+        shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
         dist.all_gather(shard_grad_list, shard_grad, tp_group)
         shard_grad = torch.cat(shard_grad_list, dim=dim)
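check_grad gets the same fix: under fp16 training a parameter's .grad tensor shares the parameter's dtype, so the gather buffers must be fp16 as well. The dtype difference in two lines (plain torch, runs anywhere):

    import torch

    shard = torch.ones(3, dtype=torch.float16)
    print(torch.zeros([*shard.shape]).dtype)   # torch.float32 -> buffer/input mismatch
    print(torch.zeros_like(shard).dtype)       # torch.float16 -> matches the shard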
@@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
         build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
-

     org_loss, org_output, sharded_loss, sharded_output = \
         run_forward_backward_with_hybrid_plugin(
             org_model,
@@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if org_model.__class__.__name__ == 'GPT2Model':
         check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

-    # check loss
     check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

 def unwrap(module):
@@ -92,13 +90,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'num_microbatches': 4,
     'enable_all_optimization': True,
     'use_lazy_init': True,
-    'precision': 'fp32',
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
     'enable_all_optimization': True,
-    'use_lazy_init': False,
+    'use_lazy_init': True,
     'precision': 'fp16',
     'initial_scale': 1,
 }, {
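This config change is what actually turns the suite into an AMP+TP test: the first entry switches from fp32 to fp16 and gains an initial_scale (plausibly the starting value of the fp16 gradient scaler; a small scale keeps a short test from skipping steps on overflow), and the second entry now also uses lazy init. A runnable sketch of the parameterized sweep; fields outside the hunk (the first entry's tp_size/pp_size) are placeholder assumptions, and the loop only mirrors what @parameterize does with these dicts:

    # Illustrative sweep over shardformer test configs.
    test_configs = [
        {   # first entry after the fix: AMP (fp16) together with tensor parallelism
            'tp_size': 2,                    # placeholder: outside the hunk
            'pp_size': 2,                    # placeholder: outside the hunk
            'num_microbatches': 4,
            'enable_all_optimization': True,
            'use_lazy_init': True,
            'precision': 'fp16',             # was 'fp32'
            'initial_scale': 1,              # assumed: initial fp16 loss-scale
        },
        {   # second entry: pipeline-only, now also lazily initialized
            'tp_size': 1,
            'pp_size': 2,
            'num_microbatches': 4,
            'enable_all_optimization': True,
            'use_lazy_init': True,           # was False
            'precision': 'fp16',
            'initial_scale': 1,
        },
    ]

    for cfg in test_configs:
        # run_gpt2_test(test_config=cfg)  # what @parameterize effectively does
        print(f"precision={cfg['precision']}, tp={cfg['tp_size']}, pp={cfg['pp_size']}")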
@@ -112,7 +111,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 def run_gpt2_test(test_config):

     # TODO: add test_config for TP+DP after supporting & debugging it
-    # TODO: check and debug TP+AMP

     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
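Dropping the "# TODO: check and debug TP+AMP" line is the payoff of the commit: with the fp16 configs above, TP+AMP is now exercised by the test rather than left as an open item.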