From 376533a56411d3826df2a5b3aabc5471016496bf Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Mon, 28 Aug 2023 10:51:16 +0800
Subject: [PATCH] [shardformer] zero1+pp and the corresponding tests (#4517)

* pause

* finish pp+zero1

* Update test_shard_vit.py
---
 colossalai/pipeline/schedule/one_f_one_b.py  |  3 +-
 colossalai/zero/low_level/low_level_optim.py |  9 ++-
 .../test_model/test_shard_bert.py            |  9 +++
 .../test_model/test_shard_bloom.py           |  9 +++
 .../test_model/test_shard_gpt2.py            |  9 +++
 .../test_model/test_shard_llama.py           |  9 +++
 .../test_model/test_shard_opt.py             |  9 +++
 .../test_model/test_shard_t5.py              |  9 +++
 .../test_model/test_shard_vit.py             | 11 ++-
 .../test_model/test_shard_whisper.py         | 67 ++++++++++---------
 10 files changed, 109 insertions(+), 35 deletions(-)

diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index f5e4929aa..0058873c2 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -128,11 +128,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
         """
         micro_batch = self.load_micro_batch()
-
         # for the first stage, input_obj is None
         # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
         output_obj = model_forward(model, micro_batch, input_obj)
         if self.stage_manager.is_last_stage():
+
             loss = criterion(output_obj, micro_batch) / self.num_microbatches
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
@@ -158,7 +158,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
 
         # Retain the grad on the input_obj.
         tree_map(retain_grad, input_obj)
-
         # Backward pass.
         if output_obj_grad is None:
             optimizer.backward(output_obj)
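Note: backward_step (shown in part above) dispatches on output_obj_grad, which is None only on the last stage. A minimal sketch of that dispatch, using a hypothetical helper name _backward; the non-None branch relies on the backward_by_grad API introduced in the next file:

def _backward(optimizer, output_obj, output_obj_grad):
    # Last stage: output_obj is the microbatch loss; start autograd from it.
    if output_obj_grad is None:
        optimizer.backward(output_obj)
    # Intermediate stages: resume autograd from this stage's output using
    # the gradient received from the next stage over p2p communication.
    else:
        optimizer.backward_by_grad(output_obj, output_obj_grad)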
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 64d6a5395..a1e85e5b9 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -316,7 +316,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def backward(self, loss, retain_graph=False):
         assert not(self._partition_grads and not self.require_grad_sync), \
             "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
-
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
@@ -333,6 +332,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         self.zero_grad()
 
+    def backward_by_grad(self, tensor, grad):
+        # In a lower stage, the grad is transferred from the higher stage,
+        # so we need to pass the optim state down.
+        if self.mixed_precision_mixin is not None:
+            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
+        torch.autograd.backward(tensor, grad)
+
     def zero_grad(self, set_to_none=True):
         """
         Set parameter gradients to zero. If set_to_none = True, gradient
@@ -358,7 +364,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
     def step(self, closure=None):
         assert closure is None, 'closure is not supported by step()'
-
         if not self.require_grad_sync:
             return
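Note: backward_by_grad is what lets ZeRO-1 run under pipeline parallelism: a non-last stage never sees a loss, only a gradient for its own output. A self-contained toy (illustrative only, not ColossalAI code) of the torch.autograd.backward(tensor, grad) call it wraps:

import torch

# Simulate an intermediate pipeline stage: y is the stage output that was
# sent forward, and grad_from_next_stage stands in for the dL/dy received
# back from the next stage.
x = torch.randn(4, requires_grad=True)       # stage-local leaf tensor
y = x * 2                                    # stage forward output
grad_from_next_stage = torch.ones_like(y)    # pretend this arrived via p2p
torch.autograd.backward(y, grad_from_next_stage)
print(x.grad)  # tensor([2., 2., 2., 2.])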
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 76f8c0541..a15645a7f 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -107,6 +107,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 
 def run_bert_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index 0e236fd47..590eff642 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -110,6 +110,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 
 def run_bloom_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 85d66e493..13458fc54 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -128,6 +128,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 485d2685e..8dc6376bf 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -142,6 +142,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 
 def run_llama_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 71483b752..939b2d555 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -135,6 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_opt_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index a853f024d..cd3d3d673 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -118,6 +118,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 @clear_cache_before_run()
 def run_t5_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 0b092966c..d40058bb7 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -45,7 +45,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if org_model.__class__.__name__ == 'ViTModel':
         check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
     check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
     # unwrap model
@@ -97,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()
 
 
+# TODO: num_microbatches = 2 leads to an inf loss
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
@@ -132,6 +132,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 
 def run_vit_test(test_config):
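Note: each config dict added above exercises pp_size=2 together with zero_stage=1. Roughly, the shared shardformer test utilities turn such a dict into a booster; the sketch below uses a hypothetical helper name build_booster_from_config and assumes use_lazy_init is popped before the remaining keys are forwarded as plugin kwargs:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

def build_booster_from_config(test_config):
    # use_lazy_init controls how the model is materialized, not the plugin.
    use_lazy_init = test_config.pop('use_lazy_init', False)
    # e.g. tp_size=1, pp_size=2, num_microbatches=2, zero_stage=1, fp16
    plugin = HybridParallelPlugin(**test_config)
    return Booster(plugin=plugin), use_lazy_init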
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 6eaed7d37..356ed6405 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -112,37 +112,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()
 
-
+#TODO fix WhisperForConditionalGeneration enable jit fused operator
 # TODO(jianghai) fix fp16
-#TODO fix WhisperForConditionalGeneration enable jit fused operator
-@parameterize('test_config', [{
-    'tp_size': 2,
-    'pp_size': 2,
-    'num_microbatches': 2,
-    'enable_all_optimization': True,
-    'use_lazy_init': True,
-    'precision': 'fp32',
-    'initial_scale': 1,
-}, {
-    'tp_size': 1,
-    'pp_size': 2,
-    'num_microbatches': 4,
-    'use_lazy_init': False,
-    'precision': 'fp32',
-    'initial_scale': 1,
-}, {
-    'tp_size': 4,
-    'pp_size': 1,
-    'enable_all_optimization': True,
-    'use_lazy_init': False,
-    'precision': 'fp32',
-}, {
-    'tp_size': 1,
-    'pp_size': 4,
-    'num_microbatches': 4,
-    'use_lazy_init': False,
-    'precision': 'fp32',
-}])
+@parameterize(
+    'test_config',
+    [
+        {
+            'tp_size': 2,
+            'pp_size': 2,
+            'num_microbatches': 2,
+            'enable_all_optimization': True,
+            'use_lazy_init': True,
+            'precision': 'fp32',
+            'initial_scale': 1,
+        },
+        {
+            'tp_size': 1,
+            'pp_size': 2,
+            'num_microbatches': 4,
+            'use_lazy_init': False,
+            'precision': 'fp32',
+            'initial_scale': 1,
+        },
+        {
+            'tp_size': 4,
+            'pp_size': 1,
+            'enable_all_optimization': True,
+            'use_lazy_init': False,
+            'precision': 'fp32',
+        },
+        {
+            'tp_size': 1,
+            'pp_size': 4,
+            'num_microbatches': 4,
+            'use_lazy_init': False,
+            'precision': 'fp32',
+        },
+        # whisper does not support fp16 for now.
+    ])
 def run_whisper_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
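Note: colossalai.testing.parameterize expands a decorated function into one call per config dict, which is how every entry above (including the new zero_stage=1 ones) gets run. A minimal illustration with toy configs and a hypothetical run_demo:

from colossalai.testing import parameterize

@parameterize('test_config', [
    {'tp_size': 1, 'pp_size': 2, 'zero_stage': 1},
    {'tp_size': 2, 'pp_size': 2, 'zero_stage': 0},
])
def run_demo(test_config):
    print(test_config)

run_demo()  # invoked once per config dict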