[shardformer] support pp+tp+zero1 tests (#4531)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1

flybird11111 2023-08-30 21:29:18 +08:00 committed by GitHub
parent d367b88785
commit ec18fc7340
10 changed files with 101 additions and 3 deletions
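
For context, the combination these tests exercise is 2-way tensor parallelism + 2-way pipeline parallelism + ZeRO stage 1, wired together through the HybridParallelPlugin. Below is a minimal sketch, not part of this commit: the model, optimizer, and criterion are placeholders, and the plugin keyword arguments are assumed to mirror the test_config keys that appear in the diffs that follow.

# Hedged sketch: tp_size=2 + pp_size=2 + zero_stage=1, launched with torchrun.
# The GPT-2 model and loss below are placeholders, not the models under test.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})   # initialize distributed from torchrun env vars

plugin = HybridParallelPlugin(
    tp_size=2,              # tensor parallel degree
    pp_size=2,              # pipeline parallel degree
    num_microbatches=4,     # microbatches per pipeline schedule step
    precision='fp16',
    zero_stage=1,           # ZeRO-1: shard optimizer states over the data-parallel group
    initial_scale=1,        # starting loss scale for fp16
)
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel(GPT2Config(n_layer=4, n_head=4, n_embd=128))   # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def criterion(outputs, inputs):   # placeholder loss fn: reuse the LM loss from the outputs
    return outputs.loss

model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
# With pp_size > 1, training steps then go through booster.execute_pipeline(...).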

@@ -333,12 +333,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self.zero_grad()
 
     def backward_by_grad(self, tensor, grad):
-        # in lower stage which grad is transfered by higher stage
-        # we need to pass the optim state down.
+        assert not(self._partition_grads and not self.require_grad_sync), \
+            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
         if self.mixed_precision_mixin is not None:
             grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
         torch.autograd.backward(tensor, grad)
 
+        if not self.require_grad_sync:
+            return
+
+        self._reduce_grad(self._partition_grads)
+
+        # clear reduced grads
+        if self._overlap_communication:
+            torch.cuda.synchronize()
+
+        self.zero_grad()
+
     def zero_grad(self, set_to_none=True):
         """
         Set parameter gradients to zero. If set_to_none = True, gradient
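
The change above does two things: it rejects combining ZeRO-2 gradient partitioning with gradient accumulation (require_grad_sync disabled), and, when grad sync is enabled, it now reduces the ZeRO gradient buckets and clears the local grads right after the autograd call, matching what backward() does for a plain loss. A rough, hypothetical sketch (not from this commit) of where backward_by_grad sits in a pipeline schedule:

# Hypothetical driver code; the names backward_step and output_grad are illustrative.
def backward_step(optimizer, output_tensor, output_grad):
    if output_grad is None:
        # Last pipeline stage: backward starts from the loss tensor itself.
        optimizer.backward(output_tensor)
    else:
        # Earlier stages: the gradient w.r.t. this stage's output was received from
        # the next stage; with the change above, the ZeRO-1 gradient reduction now
        # also runs inside this call whenever require_grad_sync is True.
        optimizer.backward_by_grad(output_tensor, output_grad)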

@@ -163,6 +163,15 @@ def run_bert_test(test_config):
         'enable_all_optimization': False,
         'use_lazy_init': False,
         'precision': 'fp32',
+    },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
         'initial_scale': 1,
     },
 ])
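
The same fp16, zero_stage=1 entry (tp_size=2, pp_size=2, 4 microbatches) is added to the bloom, chatglm, gpt2, llama, opt, and t5 tests below; vit and whisper get an fp32 tp+pp entry instead. @parameterize feeds each dict to the test body as test_config, which builds a HybridParallelPlugin from it. A hedged sketch of that mapping (the helper name is illustrative; the real tests use a shared builder, and use_lazy_init is assumed to be consumed by the test harness rather than the plugin):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

def booster_from_test_config(test_config: dict) -> Booster:
    # Illustrative helper: copy the config, pull out the key the harness handles
    # itself, and forward the rest (tp_size, pp_size, num_microbatches, precision,
    # zero_stage, initial_scale, enable_all_optimization, ...) as plugin kwargs.
    cfg = dict(test_config)
    use_lazy_init = cfg.pop('use_lazy_init', False)   # applied when the model is built
    plugin = HybridParallelPlugin(**cfg)
    return Booster(plugin=plugin)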

@@ -165,6 +165,16 @@ def run_bloom_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_bloom_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')

@@ -165,6 +165,16 @@ def run_chatglm_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_chatglm_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')

@@ -183,6 +183,16 @@ def run_gpt2_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 @clear_cache_before_run()
 def run_gpt2_3d_test(test_config):

@@ -185,6 +185,16 @@ def run_llama_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_llama_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

@@ -174,6 +174,16 @@ def run_opt_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_opt_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')

@@ -170,6 +170,16 @@ def run_t5_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_t5_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')

@@ -176,6 +176,15 @@ def run_vit_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 2,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
 ])
 def run_vit_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')

@@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check weights
     if test_config['precision'] == 'fp32':
-        atol, rtol = 5e-4, 5e-4
+        atol, rtol = 1e-3, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
@@ -195,6 +195,15 @@ def run_whisper_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 2,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
 ])
 def run_whisper_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')