diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index c9da9d32e..c51df07f6 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -210,7 +210,7 @@ def check_weight(org_model: Module,
 
         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
             sharded_weight_list = [
-                torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+                torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
             ]
             dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
             sharded_weight = torch.cat(sharded_weight_list, dim=dim)
@@ -219,7 +219,7 @@ def check_weight(org_model: Module,
             print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
 
         assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
-            f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+            f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
 
 
 def check_grad(org_model: Module,
@@ -236,9 +236,7 @@ def check_grad(org_model: Module,
         shard_weight = getattr_(sharded_model, suffix).weight
 
         if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-            shard_grad_list = [
-                torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
-            ]
+            shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
             dist.all_gather(shard_grad_list, shard_grad, tp_group)
             shard_grad = torch.cat(shard_grad_list, dim=dim)
 
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 3ac8fa26d..274cfaa39 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
         build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
-
     org_loss, org_output, sharded_loss, sharded_output = \
         run_forward_backward_with_hybrid_plugin(
             org_model,
@@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         if org_model.__class__.__name__ == 'GPT2Model':
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
 
-        # check loss
        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
    def unwrap(module):
@@ -92,13 +90,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'num_microbatches': 4,
     'enable_all_optimization': True,
     'use_lazy_init': True,
-    'precision': 'fp32',
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
     'enable_all_optimization': True,
-    'use_lazy_init': False,
+    'use_lazy_init': True,
     'precision': 'fp16',
     'initial_scale': 1,
 }, {
@@ -112,7 +111,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 def run_gpt2_test(test_config):
 
     # TODO: add test_config for TP+DP after supporting & debugging it
-    # TODO: check and debug TP+AMP
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')