[shardformer] zero1+pp and the corresponding tests (#4517)

* pause

* finish pp+zero1

* Update test_shard_vit.py
Jianghai 2023-08-28 10:51:16 +08:00 committed by GitHub
parent 44eab2b27f
commit 376533a564
10 changed files with 109 additions and 35 deletions

@@ -128,11 +128,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch()
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
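For context, a toy illustration (not from the commit) of the contract described in the comments above: the first stage sees input_obj as None and consumes only the micro-batch, while later stages also receive the previous stage's output as a dict. The names model_forward_sketch and 'hidden_states' are made up for the example.

import torch

def model_forward_sketch(stage_module, micro_batch, input_obj):
    if input_obj is None:
        # first stage: only the micro-batch is available
        hidden = stage_module(micro_batch['x'])
    else:
        # later stage: consume the dict sent by the previous stage
        hidden = stage_module(input_obj['hidden_states'])
    return {'hidden_states': hidden}

stage0 = torch.nn.Linear(4, 4)
stage1 = torch.nn.Linear(4, 2)
micro_batch = {'x': torch.randn(3, 4)}
out0 = model_forward_sketch(stage0, micro_batch, None)   # stage 0: input_obj is None
out1 = model_forward_sketch(stage1, micro_batch, out0)   # stage 1: gets stage 0's dict
print(out1['hidden_states'].shape)                       # torch.Size([3, 2])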
@@ -158,7 +158,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
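For context, a toy illustration (not from the commit) of how a 1F1B backward step routes through the optimizer wrapper: the last pipeline stage backpropagates the loss directly, while earlier stages seed autograd with the gradient received from the next stage, which is what the backward_by_grad path added below is for. _ToyOptimizer is a stand-in, not the real LowLevelZeroOptimizer, and output_obj / output_obj_grad are treated as plain tensors for brevity.

import torch

class _ToyOptimizer:
    # minimal stand-in exposing the two entry points the schedule uses
    def backward(self, loss):
        loss.backward()

    def backward_by_grad(self, tensor, grad):
        torch.autograd.backward(tensor, grad)

def backward_step_sketch(optimizer, output_obj, output_obj_grad):
    if output_obj_grad is None:
        # last stage: output_obj is the (already scaled) loss
        optimizer.backward(output_obj)
    else:
        # earlier stages: backprop from the gradient produced by the next stage
        optimizer.backward_by_grad(output_obj, output_obj_grad)

opt = _ToyOptimizer()
x = torch.randn(3, requires_grad=True)
backward_step_sketch(opt, (x * 3).sum(), None)        # last-stage path
y = torch.randn(3, requires_grad=True)
backward_step_sketch(opt, y * 3, torch.ones(3))       # earlier-stage path
print(x.grad, y.grad)                                 # both equal tensor([3., 3., 3.])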

@@ -316,7 +316,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
@@ -333,6 +332,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.zero_grad()
def backward_by_grad(self, tensor, grad):
# in lower stages, the grad is transferred from the higher stage,
# so we need to pass the optim state down.
if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
@@ -358,7 +364,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
if not self.require_grad_sync:
return
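Stripped of the mixed-precision hook (which, when present, preprocesses the incoming gradient first), backward_by_grad reduces to torch.autograd.backward(tensor, grad): backprop seeded with an externally supplied upstream gradient rather than a scalar loss. A small self-contained illustration with plain tensors:

import torch

x = torch.randn(4, requires_grad=True)
y = x * 2                       # stand-in for this stage's pipeline output
upstream_grad = torch.ones(4)   # stand-in for the gradient received from the next stage

# equivalent to loss.backward() with loss = (y * upstream_grad).sum()
torch.autograd.backward(y, upstream_grad)
print(x.grad)                   # tensor([2., 2., 2., 2.])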

@@ -107,6 +107,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bert_test(test_config):
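Each dict in the parameterize list above is one hybrid-parallel configuration; the new entry exercises pp_size=2 together with zero_stage=1. As a rough sketch of how such a config could be mapped onto ColossalAI's HybridParallelPlugin (the real tests go through a shared helper, so the wiring and defaults below are assumptions):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

def build_booster(test_config):
    # keys mirror the dicts above; .get() defaults are illustrative only
    plugin = HybridParallelPlugin(
        tp_size=test_config['tp_size'],
        pp_size=test_config['pp_size'],
        num_microbatches=test_config.get('num_microbatches'),
        zero_stage=test_config.get('zero_stage', 0),
        enable_all_optimization=test_config.get('enable_all_optimization', False),
        precision=test_config.get('precision', 'fp16'),
        initial_scale=test_config.get('initial_scale', 2**16),
    )
    return Booster(plugin=plugin)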

@@ -110,6 +110,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bloom_test(test_config):

@@ -128,6 +128,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):

@@ -142,6 +142,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_llama_test(test_config):

@@ -135,6 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_opt_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')

@@ -118,6 +118,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_t5_test(test_config):

@@ -45,7 +45,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
@@ -97,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
torch.cuda.empty_cache()
# TODO: num_microbatches = 2 gives an inf loss
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
@@ -132,6 +132,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': False,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_vit_test(test_config):

@@ -112,37 +112,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
torch.cuda.empty_cache()
#TODO fix WhisperForConditionalGeneration enable jit fused operato
# TODO(jianghai) fix fp16
#TODO fix WhisperForConditionalGeneration enable jit fused operator
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
}])
@parameterize(
'test_config',
[
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
},
# whisper does not support fp16 for now.
])
def run_whisper_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():