mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] zero1+pp and the corresponding tests (#4517)
* pause
* finish pp+zero1
* Update test_shard_vit.py
(branch: pull/4526/head)
parent 44eab2b27f
commit 376533a564
@@ -128,11 +128,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
        """
        micro_batch = self.load_micro_batch()

        # For the first stage, input_obj is None;
        # for every other stage, input_obj is the output of the previous stage and must be a dict.
        output_obj = model_forward(model, micro_batch, input_obj)
        if self.stage_manager.is_last_stage():
            loss = criterion(output_obj, micro_batch) / self.num_microbatches
            if accum_loss is not None:
                accum_loss.add_(loss.detach())

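For orientation, here is a minimal, self-contained sketch (plain PyTorch, not the schedule above) of why each micro-batch loss is divided by num_microbatches before being accumulated: the running sum then matches the loss computed over the full batch.

import torch

# Toy model and criterion, purely for illustration.
num_microbatches = 4
full_batch = torch.randn(8, 16)          # 8 samples, 16 features
target = torch.randn(8, 1)
model = torch.nn.Linear(16, 1)
criterion = torch.nn.MSELoss()

# Loss over the full batch in one shot.
full_loss = criterion(model(full_batch), target)

# Same loss accumulated micro-batch by micro-batch, as the schedule does.
accum_loss = torch.zeros(1)
for mb_x, mb_y in zip(full_batch.chunk(num_microbatches), target.chunk(num_microbatches)):
    loss = criterion(model(mb_x), mb_y) / num_microbatches
    accum_loss.add_(loss.detach())

assert torch.allclose(full_loss, accum_loss, atol=1e-5)
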
@@ -158,7 +158,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        # Retain the grad on the input_obj.
        tree_map(retain_grad, input_obj)

        # Backward pass.
        if output_obj_grad is None:
            optimizer.backward(output_obj)

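The retained gradient on input_obj is what the schedule hands back to the previous stage, while optimizer.backward(output_obj) is only valid on the last stage, where a loss exists. A toy two-stage illustration of that split (hypothetical modules, not the scheduler's code):

import torch
import torch.nn as nn

stage0 = nn.Linear(4, 4)   # hypothetical previous stage
stage1 = nn.Linear(4, 1)   # hypothetical current (last) stage

x = torch.randn(2, 4)
# The output of stage 0 is the input object of stage 1; it is a non-leaf tensor,
# so its .grad is only populated if we retain it explicitly.
input_obj = stage0(x)
input_obj.retain_grad()

loss = stage1(input_obj).sum()
loss.backward()            # last stage: backward starts from the loss

# input_obj.grad is what a pipeline schedule would ship to stage 0, which would
# then continue with torch.autograd.backward(its_output, input_obj.grad).
assert input_obj.grad is not None
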
@@ -316,7 +316,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
    def backward(self, loss, retain_graph=False):
        assert not(self._partition_grads and not self.require_grad_sync), \
            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)

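pre_backward is the hook where a mixed-precision setup typically scales the loss before autograd runs. A rough sketch of the idea with an illustrative mixin class (not ColossalAI's mixin API):

import torch

class ToyFP16Mixin:
    """Illustrative stand-in for a mixed-precision mixin: scale the loss
    before backward so fp16 gradients do not underflow."""

    def __init__(self, initial_scale: float = 2.0 ** 16):
        self.loss_scale = initial_scale

    def pre_backward(self, loss: torch.Tensor) -> torch.Tensor:
        return loss * self.loss_scale

mixin = ToyFP16Mixin(initial_scale=1)   # the tests above also use initial_scale=1
w = torch.nn.Parameter(torch.ones(3))
loss = (w * 2).sum()
loss = mixin.pre_backward(loss)         # scale (here by 1, so effectively a no-op)
loss.backward()                         # grads carry the scale; they are unscaled at step time
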
@@ -333,6 +332,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        self.zero_grad()

    def backward_by_grad(self, tensor, grad):
        # In lower (non-last) pipeline stages the grad is transferred from the higher stage,
        # so we need to pass the optimizer state down.
        if self.mixed_precision_mixin is not None:
            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
        torch.autograd.backward(tensor, grad)

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, gradient

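backward_by_grad is what a non-last pipeline stage calls: it has no loss of its own, only the gradient of its output received from the next stage. A self-contained sketch of that pattern in plain PyTorch (toy modules, illustrative names):

import torch
import torch.nn as nn

# Toy two-stage pipeline split, purely for illustration.
stage0, stage1 = nn.Linear(4, 4), nn.Linear(4, 1)

x = torch.randn(2, 4)
out0 = stage0(x)                         # stage 0 forward
inp1 = out0.detach().requires_grad_()    # what stage 1 receives over the wire
loss = stage1(inp1).sum()

loss.backward()                          # stage 1: backward starts from the loss
grad_from_next_stage = inp1.grad         # shipped back to stage 0

# Stage 0 has no loss, so backward starts from the (tensor, grad) pair,
# which is exactly what backward_by_grad wraps.
torch.autograd.backward(out0, grad_from_next_stage)
assert stage0.weight.grad is not None
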
@@ -358,7 +364,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        if not self.require_grad_sync:
            return

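The early return when require_grad_sync is False mirrors the usual gradient-accumulation pattern: while synchronization is deferred, backward() keeps accumulating and the parameter update is skipped. A generic sketch of that control flow (plain PyTorch, not the booster/no_sync API):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4

for i, batch in enumerate(torch.randn(8, 2, 4)):
    sync = (i + 1) % accum_steps == 0        # plays the role of require_grad_sync
    loss = model(batch).sum() / accum_steps
    loss.backward()                          # gradients keep accumulating
    if sync:                                 # only a "sync" step actually updates
        optimizer.step()
        optimizer.zero_grad()
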
@@ -107,6 +107,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_bert_test(test_config):

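For context on how such a test_config is consumed, here is a sketch of the mapping onto ColossalAI's HybridParallelPlugin; the exact constructor arguments and the handling of use_lazy_init are assumptions inferred from the keys above, and a distributed launch is required to actually run it.

# Sketch only: assumes HybridParallelPlugin accepts these keyword arguments and
# that the distributed environment has already been set up (e.g. via colossalai.launch).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

test_config = {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,          # the newly covered ZeRO-1 + pipeline combination
    'precision': 'fp16',
    'initial_scale': 1,
}

use_lazy_init = test_config.pop('use_lazy_init')   # assumed to be handled by the test harness, not the plugin
plugin = HybridParallelPlugin(**test_config)       # tp/pp/zero settings come straight from the dict
booster = Booster(plugin=plugin)
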
@@ -110,6 +110,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_bloom_test(test_config):

@@ -128,6 +128,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):

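@clear_cache_before_run() is a helper from colossalai.testing; as a purely illustrative reimplementation of the pattern (not the library's actual code), such a cache-clearing test decorator can be written as:

import functools
import torch

def clear_cache_before_run():
    """Illustrative decorator: free cached CUDA memory before the wrapped test runs."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@clear_cache_before_run()
def run_dummy_test():
    return "ok"
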
@@ -142,6 +142,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_llama_test(test_config):

@@ -135,6 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_opt_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')

@@ -118,6 +118,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
@clear_cache_before_run()
def run_t5_test(test_config):

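Since the point of the commit is letting 'zero_stage': 1 coexist with pipeline parallelism, here is a toy reminder of what ZeRO-1 sharding means: each data-parallel rank keeps full gradients but owns optimizer states only for its slice of the parameters (pure illustration, no distributed calls):

import torch

params = [torch.nn.Parameter(torch.randn(10)) for _ in range(8)]
world_size = 4                      # pretend data-parallel group size

# ZeRO-1: each rank owns optimizer states only for its slice of the parameters...
owned = {rank: params[rank::world_size] for rank in range(world_size)}

# ...but gradients stay full-size on every rank, so pipeline stages can still
# run their local backward passes unchanged.
for p in params:
    p.grad = torch.zeros_like(p)

for rank, shard in owned.items():
    print(f"rank {rank} keeps optimizer states for {len(shard)} of {len(params)} params")
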
@@ -45,7 +45,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    if org_model.__class__.__name__ == 'ViTModel':
        check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

    check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # unwrap model

@@ -97,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    torch.cuda.empty_cache()


# TODO: with num_microbatches = 2 the loss becomes inf
@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,

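check_loss and check_output_hidden_state compare the sharded run against the unsharded model within the given tolerances; a minimal stand-in showing what such a check amounts to (hypothetical helper, essentially torch.allclose):

import torch

def check_loss_sketch(org_loss: torch.Tensor, sharded_loss: torch.Tensor,
                      atol: float = 1e-5, rtol: float = 1e-3) -> None:
    # Hypothetical stand-in for the shared test helper: the two losses must agree
    # within the tolerances chosen for the precision under test.
    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
        f"loss mismatch: {org_loss.item()} vs {sharded_loss.item()}"

check_loss_sketch(torch.tensor(2.3456), torch.tensor(2.3455), atol=1e-3, rtol=1e-3)
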
@@ -132,6 +132,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_vit_test(test_config):

@@ -112,37 +112,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    torch.cuda.empty_cache()


#TODO fix WhisperForConditionalGeneration enable jit fused operato
# TODO(jianghai) fix fp16
#TODO fix WhisperForConditionalGeneration enable jit fused operator
@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'precision': 'fp32',
    'initial_scale': 1,
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'use_lazy_init': False,
    'precision': 'fp32',
    'initial_scale': 1,
}, {
    'tp_size': 4,
    'pp_size': 1,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32',
}, {
    'tp_size': 1,
    'pp_size': 4,
    'num_microbatches': 4,
    'use_lazy_init': False,
    'precision': 'fp32',
}])
@parameterize(
    'test_config',
    [
        {
            'tp_size': 2,
            'pp_size': 2,
            'num_microbatches': 2,
            'enable_all_optimization': True,
            'use_lazy_init': True,
            'precision': 'fp32',
            'initial_scale': 1,
        },
        {
            'tp_size': 1,
            'pp_size': 2,
            'num_microbatches': 4,
            'use_lazy_init': False,
            'precision': 'fp32',
            'initial_scale': 1,
        },
        {
            'tp_size': 4,
            'pp_size': 1,
            'enable_all_optimization': True,
            'use_lazy_init': False,
            'precision': 'fp32',
        },
        {
            'tp_size': 1,
            'pp_size': 4,
            'num_microbatches': 4,
            'use_lazy_init': False,
            'precision': 'fp32',
        },
        # whisper is not supported fp16 for now.
    ])
def run_whisper_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

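@parameterize from colossalai.testing runs the decorated test once per entry of the supplied list; a self-contained sketch of that behavior (illustrative reimplementation, not the library code):

import functools

def parameterize(arg_name, values):
    """Illustrative version: call the wrapped function once per config in the list."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize('test_config', [{'tp_size': 1, 'pp_size': 2, 'zero_stage': 1},
                              {'tp_size': 2, 'pp_size': 2, 'zero_stage': 0}])
def run_demo_test(test_config):
    print('running with', test_config)

run_demo_test()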