
[shardformer] support pp+tp+zero1 tests (#4531)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1

Branch: pull/4544/head
Author: flybird11111, committed by GitHub
Commit: ec18fc7340
  1. colossalai/zero/low_level/low_level_optim.py (15 changed lines)
  2. tests/test_shardformer/test_model/test_shard_bert.py (9 changed lines)
  3. tests/test_shardformer/test_model/test_shard_bloom.py (10 changed lines)
  4. tests/test_shardformer/test_model/test_shard_chatglm2.py (10 changed lines)
  5. tests/test_shardformer/test_model/test_shard_gpt2.py (10 changed lines)
  6. tests/test_shardformer/test_model/test_shard_llama.py (10 changed lines)
  7. tests/test_shardformer/test_model/test_shard_opt.py (10 changed lines)
  8. tests/test_shardformer/test_model/test_shard_t5.py (10 changed lines)
  9. tests/test_shardformer/test_model/test_shard_vit.py (9 changed lines)
  10. tests/test_shardformer/test_model/test_shard_whisper.py (11 changed lines)

colossalai/zero/low_level/low_level_optim.py (15 changed lines)

@@ -333,12 +333,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        self.zero_grad()

    def backward_by_grad(self, tensor, grad):
        # in the lower pipeline stage, the grad is transferred from the higher stage;
        # we need to pass the optim state down.
        assert not(self._partition_grads and not self.require_grad_sync), \
            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
        if self.mixed_precision_mixin is not None:
            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
        torch.autograd.backward(tensor, grad)

        if not self.require_grad_sync:
            return
        self._reduce_grad(self._partition_grads)

        # clear reduced grads
        if self._overlap_communication:
            torch.cuda.synchronize()

        self.zero_grad()

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, gradient
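
The new assert spells out the constraint this change relies on: gradient synchronization can be deferred across microbatches only with ZeRO-1, because ZeRO-2 (partition_grads) reduces and partitions gradients as it goes, which the assert message flags as incompatible with no_sync-style accumulation. Below is a minimal sketch of the call pattern backward_by_grad serves under pipeline parallelism; it is illustrative only. The driver loop, model_chunk, the received output_grad, and the direct toggling of require_grad_sync are assumptions (in the real code the pipeline schedule and a no_sync-style context manage this), and optimizer is taken to be an already constructed LowLevelZeroOptimizer wrapping this stage's parameters.

import torch

# Hypothetical single pipeline stage: one linear layer stands in for the model chunk.
# `optimizer` is assumed to be an existing LowLevelZeroOptimizer over model_chunk.parameters().
model_chunk = torch.nn.Linear(8, 8)
micro_batches = [torch.randn(4, 8) for _ in range(4)]

for i, micro_batch in enumerate(micro_batches):
    output = model_chunk(micro_batch)                 # forward pass of this stage
    output_grad = torch.randn_like(output)            # stand-in for the grad received from the next stage
    # defer the ZeRO-1 gradient reduction until the last microbatch
    optimizer.require_grad_sync = (i == len(micro_batches) - 1)
    optimizer.backward_by_grad(output, output_grad)   # backward + (on the last microbatch) grad reduction
optimizer.step()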

tests/test_shardformer/test_model/test_shard_bert.py (9 changed lines)

@@ -163,6 +163,15 @@ def run_bert_test(test_config):
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp32',
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
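
The new entry is the combination this PR adds coverage for: tp_size=2 and pp_size=2 with ZeRO stage 1 in fp16. As a rough sketch of how such a config usually becomes a runnable setup, assuming the test harness forwards it to HybridParallelPlugin and strips the harness-only use_lazy_init key first (both assumptions about the harness, not shown in this diff):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# assumes colossalai.launch(...) has already set up a 4-rank job (tp_size * pp_size = 4)
test_config = {
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,        # ZeRO-1 stacked on top of TP + PP
    'initial_scale': 1,
}
use_lazy_init = test_config.pop('use_lazy_init')    # harness-level flag, not a plugin argument (assumption)
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
# model, optimizer, criterion and dataloader would then be wrapped with booster.boost(...)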

tests/test_shardformer/test_model/test_shard_bloom.py (10 changed lines)

@@ -165,6 +165,16 @@ def run_bloom_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_bloom_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')

tests/test_shardformer/test_model/test_shard_chatglm2.py (10 changed lines)

@@ -165,6 +165,16 @@ def run_chatglm_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_chatglm_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')

tests/test_shardformer/test_model/test_shard_gpt2.py (10 changed lines)

@@ -183,6 +183,16 @@ def run_gpt2_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
@clear_cache_before_run()
def run_gpt2_3d_test(test_config):

tests/test_shardformer/test_model/test_shard_llama.py (10 changed lines)

@@ -185,6 +185,16 @@ def run_llama_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_llama_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

tests/test_shardformer/test_model/test_shard_opt.py (10 changed lines)

@@ -174,6 +174,16 @@ def run_opt_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_opt_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')

tests/test_shardformer/test_model/test_shard_t5.py (10 changed lines)

@@ -170,6 +170,16 @@ def run_t5_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_t5_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')

tests/test_shardformer/test_model/test_shard_vit.py (9 changed lines)

@@ -176,6 +176,15 @@ def run_vit_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp32',
    'initial_scale': 1,
},
])
def run_vit_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')

tests/test_shardformer/test_model/test_shard_whisper.py (11 changed lines)

@@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # check weights
    if test_config['precision'] == 'fp32':
-       atol, rtol = 5e-4, 5e-4
+       atol, rtol = 1e-3, 1e-3
    else:
        atol, rtol = 5e-3, 5e-3
    if stage_manager is None or stage_manager.is_first_stage():

@@ -195,6 +195,15 @@ def run_whisper_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp32',
    'initial_scale': 1,
},
])
def run_whisper_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
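
For whisper, the first hunk relaxes the fp32 weight tolerance from 5e-4 to 1e-3, and the second adds a pp+tp config (fp32, 2 microbatches) like the other models. The sketch below shows roughly the kind of comparison that tolerance gates; it is a simplified stand-in for the test's weight check (it compares parameters name by name and skips the tensor-parallel gathering and layer filtering the real check performs):

import torch

def check_weights_close(org_model: torch.nn.Module, sharded_model: torch.nn.Module,
                        atol: float = 1e-3, rtol: float = 1e-3) -> None:
    # Simplified illustration: compare each parameter of the sharded model against
    # the original one after a training step, using the relaxed fp32 tolerance.
    for (name, p_org), (_, p_shard) in zip(org_model.named_parameters(),
                                           sharded_model.named_parameters()):
        torch.testing.assert_close(p_shard.float(), p_org.float(), atol=atol, rtol=rtol,
                                   msg=f"parameter {name} diverged beyond atol={atol}, rtol={rtol}")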
