mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support pp+tp+zero1 tests (#4531)
* [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1pull/4544/head
parent
d367b88785
commit
ec18fc7340
|
@ -333,12 +333,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
self.zero_grad()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
# in lower stage which grad is transfered by higher stage
|
||||
# we need to pass the optim state down.
|
||||
assert not(self._partition_grads and not self.require_grad_sync), \
|
||||
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
|
||||
|
||||
if self.mixed_precision_mixin is not None:
|
||||
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
|
||||
torch.autograd.backward(tensor, grad)
|
||||
|
||||
if not self.require_grad_sync:
|
||||
return
|
||||
self._reduce_grad(self._partition_grads)
|
||||
|
||||
# clear reduced grads
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
self.zero_grad()
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""
|
||||
Set parameter gradients to zero. If set_to_none = True, gradient
|
||||
|
|
|
@ -163,6 +163,15 @@ def run_bert_test(test_config):
|
|||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'zero_stage': 1,
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
|
|
|
@ -165,6 +165,16 @@ def run_bloom_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'zero_stage': 1,
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_bloom_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
|
|
|
@ -165,6 +165,16 @@ def run_chatglm_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'zero_stage': 1,
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_chatglm_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
|
||||
|
|
|
@ -183,6 +183,16 @@ def run_gpt2_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'zero_stage': 1,
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_3d_test(test_config):
|
||||
|
|
|
@ -185,6 +185,16 @@ def run_llama_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'zero_stage': 1,
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_llama_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
|
|
|
@ -174,6 +174,16 @@ def run_opt_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'zero_stage': 1,
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_opt_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
|
||||
|
|
|
@ -170,6 +170,16 @@ def run_t5_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp16',
|
||||
'zero_stage': 1,
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_t5_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
|
|
|
@ -176,6 +176,15 @@ def run_vit_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 2,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_vit_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||
|
|
|
@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
# check weights
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 5e-4, 5e-4
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
|
@ -195,6 +195,15 @@ def run_whisper_test(test_config):
|
|||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 2,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_whisper_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
|
||||
|
|
Loading…
Reference in New Issue