diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index be9d1c58b..abe491bfa 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -103,21 +103,21 @@ class OPTPolicy(Policy):
                                                         target_key=OPTDecoderLayer)
 
         # use flash attention
-        # if self.shard_config.enable_flash_attention:
-        #     self.append_or_create_method_replacement(description={
-        #         'forward': get_opt_flash_attention_forward(),
-        #     },
-        #                                              policy=policy,
-        #                                              target_key=OPTAttention)
+        if self.shard_config.enable_flash_attention:
+            self.append_or_create_method_replacement(description={
+                'forward': get_opt_flash_attention_forward(),
+            },
+                                                     policy=policy,
+                                                     target_key=OPTAttention)
 
         # use jit fused operator
-        # if self.shard_config.enable_jit_fused:
-        #     self.append_or_create_method_replacement(description={
-        #         'forward': get_jit_fused_opt_decoder_layer_forward(),
-        #         'dropout_add': get_jit_fused_dropout_add_func(),
-        #     },
-        #                                              policy=policy,
-        #                                              target_key=OPTDecoderLayer)
+        if self.shard_config.enable_jit_fused:
+            self.append_or_create_method_replacement(description={
+                'forward': get_jit_fused_opt_decoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            },
+                                                     policy=policy,
+                                                     target_key=OPTDecoderLayer)
 
         return policy
 
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 192a1b847..92cbd3f72 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -184,24 +184,33 @@ class T5BasePolicy(Policy):
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[T5Attention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_t5_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=T5Attention)
 
         # use jit operator
         if self.shard_config.enable_jit_fused:
-            policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_T5_layer_ff_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
+            },
+                                                     policy=policy,
+                                                     target_key=T5LayerFF)
+            self.append_or_create_method_replacement(description={
                 'forward': get_T5_layer_self_attention_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
+            },
+                                                     policy=policy,
+                                                     target_key=T5LayerSelfAttention)
+            self.append_or_create_method_replacement(description={
                 'forward': get_T5_layer_cross_attention_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=T5LayerCrossAttention)
+
         return policy
 
     def postprocess(self):
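The OPT and T5 hunks above converge on the same registration pattern: instead of assigning a fresh `ModulePolicyDescription` (which would clobber any replacements already registered for the same target class), every conditional feature goes through `append_or_create_method_replacement`. Below is a minimal, self-contained sketch of that merge behavior; the class and helper are simplified stand-ins written for illustration, not the real shardformer implementations in `colossalai/shardformer/policies/base_policy.py`.

```python
# Sketch only: simplified stand-ins for the shardformer policy machinery.
from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass
class ModulePolicyDescription:
    method_replacement: Dict[str, Callable] = field(default_factory=dict)


def append_or_create_method_replacement(description: Dict[str, Callable],
                                        policy: Dict[type, ModulePolicyDescription],
                                        target_key: type) -> None:
    # Merge into the existing description instead of overwriting the policy
    # entry, so replacements registered earlier for the same target_key
    # (e.g. tensor-parallel forwards) are preserved.
    if target_key not in policy:
        policy[target_key] = ModulePolicyDescription()
    policy[target_key].method_replacement.update(description)


class DummyAttention:
    pass


policy: Dict[type, ModulePolicyDescription] = {}
append_or_create_method_replacement({'forward': lambda self: 'flash forward'}, policy, DummyAttention)
append_or_create_method_replacement({'dropout_add': lambda *a: 'fused dropout_add'}, policy, DummyAttention)
# Both replacements coexist on the same target, which plain assignment would not allow.
assert set(policy[DummyAttention].method_replacement) == {'forward', 'dropout_add'}
```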
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index bffb624d0..5d496f08e 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -56,9 +56,6 @@ class WhisperPolicy(Policy):
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(
                 "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
-        if self.shard_config.enable_jit_fused:
-            self.shard_config.enable_jit_fused = False
-            warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.")
 
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
@@ -212,6 +209,21 @@ class WhisperPolicy(Policy):
                                                         policy=policy,
                                                         target_key=WhisperAttention)
 
+        # use jit fused operator
+        if self.shard_config.enable_jit_fused:
+            self.append_or_create_method_replacement(description={
+                'forward': get_jit_fused_whisper_decoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            },
+                                                     policy=policy,
+                                                     target_key=WhisperDecoderLayer)
+            self.append_or_create_method_replacement(description={
+                'forward': get_jit_fused_whisper_encoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            },
+                                                     policy=policy,
+                                                     target_key=WhisperEncoderLayer)
+
         return policy
 
     def add_lm_head_policy(self, base_policy):
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 803afc48a..72bb2b025 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -237,6 +237,43 @@ def check_weight(org_model: Module,
         f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
 
 
+def get_grad_tensors_for_check(org_model: Module,
+                               sharded_model: Module,
+                               layer_suffix: List[str],
+                               tp_group: ProcessGroup = None,
+                               dim: int = 0,
+                               atol: float = 1e-5,
+                               rtol: float = 1e-3,
+                               verbose: bool = False,
+                               name: str = None):
+
+    grad_to_check = {}
+    for suffix in layer_suffix:
+        org_grad = getattr_(org_model, suffix).weight.grad
+        shard_grad = getattr_(sharded_model, suffix).weight.grad
+        shard_weight = getattr_(sharded_model, suffix).weight
+        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+            shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
+            dist.all_gather(shard_grad_list, shard_grad, tp_group)
+            shard_grad = torch.cat(shard_grad_list, dim=dim)
+
+        # embedding may be resized when using tensor parallel
+        if shard_grad.shape[0] > org_grad.shape[0]:
+            shard_grad = shard_grad[:org_grad.shape[0], :]
+        if verbose and dist.get_rank() == 0:
+            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
+        grad_to_check[suffix] = {
+            "org_grad": org_grad.float(),
+            "shard_grad": shard_grad.float(),
+            "rtol": rtol,
+            "atol": atol
+        }
+
+    return grad_to_check
+
+
+# used by sam/blip2
 def check_grad(org_model: Module,
                sharded_model: Module,
                layer_suffix: List[str],
@@ -275,3 +312,18 @@ def unwrap_model(module: Module,
         if module.__class__.__name__ == base_model_class_name:
             return module
     return getattr(module, base_model_attribute_name, None)
+
+
+def check_all_grad_tensors(check_tensors):
+    """
+    "org_grad": tensor to be compared from the original model
+    "shard_grad": tensor to be compared from the sharded model
+    """
+    for suffix, check_info in check_tensors.items():
+        org_grad = check_info["org_grad"]
+        shard_grad = check_info["shard_grad"]
+        rtol = check_info["rtol"]
+        atol = check_info["atol"]
+        assert torch.allclose(
+            org_grad, shard_grad, atol=atol, rtol=rtol
+        ), f"error attribute '{suffix}', origin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
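The two new `_utils.py` helpers split gradient verification into a snapshot phase and an assertion phase: `get_grad_tensors_for_check` collects gradients (all-gathering sharded ones across the tensor-parallel group) into a plain dict before `optimizer.step()` runs, and `check_all_grad_tensors` asserts on those snapshots afterwards. A toy single-process sketch of that dict contract follows; the `'weight'` key and the tolerances are illustrative values, not taken from the tests.

```python
# Toy illustration of the snapshot-then-check contract defined above.
import torch

model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

# Snapshot BEFORE optimizer.step(), mirroring the grad_to_check entry layout.
grads_to_check = {
    'weight': {
        "org_grad": model.weight.grad.float(),
        "shard_grad": model.weight.grad.float(),    # stand-in for the gathered shard grad
        "rtol": 1e-3,
        "atol": 1e-5,
    }
}

torch.optim.SGD(model.parameters(), lr=0.1).step()    # .grad buffers are untouched by step()

# Deferred assertion, as in check_all_grad_tensors.
for suffix, info in grads_to_check.items():
    assert torch.allclose(info["org_grad"], info["shard_grad"],
                          atol=info["atol"], rtol=info["rtol"]), suffix
```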
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index a15645a7f..61881a1f9 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -33,18 +34,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                                                          output_transform_fn, criterion, booster)
 
+    stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        if org_model.__class__.__name__ == 'BertModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
     bert = unwrap_model(org_model, 'BertModel', 'bert')
     sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
@@ -52,17 +44,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     col_layer_for_check = ['encoder.layer[0].output.dense']
     row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
 
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
     if test_config['precision'] == 'fp32':
         atol, rtol = 1e-4, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
-        check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
-        check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-
-    # check weights after optimizer.step()
+        col_layer_grads = get_grad_tensors_for_check(bert,
+                                                     sharded_bert,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1,
+                                                     verbose=False)
+        row_layer_grads = get_grad_tensors_for_check(bert,
+                                                     sharded_bert,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0,
+                                                     verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'BertModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if test_config['precision'] == 'fp32':
         atol, rtol = 5e-3, 1e-3
     else:
@@ -70,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if stage_manager is None or stage_manager.is_first_stage():
         check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
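Every per-model test below repeats the reordering shown here for BERT: snapshot grads, step both optimizers, check hidden states and loss, check weights, then assert on the grad snapshots last. The `layer_suffix` strings these helpers receive, such as `'encoder.layer[0].output.dense'`, are resolved by the `getattr_` utility; the following is a hypothetical re-implementation for illustration only (the real helper ships with colossalai and may behave differently).

```python
# Hypothetical sketch of getattr_-style suffix resolution; not the real helper.
import re

import torch.nn as nn


def getattr_(obj, path: str):
    """Resolve dotted attribute paths that may include list indices."""
    for part in path.split('.'):
        match = re.fullmatch(r'(\w+)\[(\d+)\]', part)
        if match:
            # 'layer[0]' -> getattr(obj, 'layer')[0]
            obj = getattr(obj, match.group(1))[int(match.group(2))]
        else:
            obj = getattr(obj, part)
    return obj


encoder = nn.Module()
encoder.layer = nn.ModuleList([nn.Linear(2, 2)])
assert isinstance(getattr_(encoder, 'layer[0]').weight, nn.Parameter)
```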
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index 590eff642..f7ab94bc9 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -36,35 +37,54 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        if org_model.__class__.__name__ == 'BloomModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
     bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
     sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
 
-    # check grad
     row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
     col_layer_for_check = ['h[0].self_attention.dense']
+
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-5
         else:
             atol, rtol = 5e-3, 5e-3
-        check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-        check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
-
-    # check weights after optimizer.step()
+        row_layer_grads = get_grad_tensors_for_check(bloom,
+                                                     sharded_bloom,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0,
+                                                     verbose=False)
+        col_layer_grads = get_grad_tensors_for_check(bloom,
+                                                     sharded_bloom,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1,
+                                                     verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'BloomModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
@@ -72,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 5e-3
         check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py
index a8957d8d3..c5a3e68f7 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm2.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -36,51 +37,57 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        if org_model.__class__.__name__ == 'ChatGLMModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
     chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
     shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
 
-    # check grad
     row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
     col_layer_for_check = ['encoder.layers[0].self_attention.dense']
+
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_grad(chatglm_model,
-                   shard_chatglm_model,
-                   row_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=0,
-                   verbose=False)
-
-        check_grad(chatglm_model,
-                   shard_chatglm_model,
-                   col_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=1,
-                   verbose=False)
-
-    # check weights after optimizer.step()
+        row_layer_grads = get_grad_tensors_for_check(chatglm_model,
+                                                     shard_chatglm_model,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0,
+                                                     verbose=False)
+
+        col_layer_grads = get_grad_tensors_for_check(chatglm_model,
+                                                     shard_chatglm_model,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1,
+                                                     verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == 'ChatGLMModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
@@ -95,6 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 13458fc54..44914721c 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -36,18 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        if org_model.__class__.__name__ == 'GPT2Model':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
     gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
     sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
@@ -55,18 +44,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     col_layer_for_check = ['h[0].mlp.c_fc']
     row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
 
-    # check grad
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
-        check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-
-    # check weights after optimizer.step()
+        col_layer_grads = get_grad_tensors_for_check(gpt2,
+                                                     sharded_gpt2,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1,
+                                                     verbose=False)
+        row_layer_grads = get_grad_tensors_for_check(gpt2,
+                                                     sharded_gpt2,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0,
+                                                     verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == 'GPT2Model':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 5e-3, 1e-3
@@ -74,6 +94,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 5e-3
         check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 8dc6376bf..c9d5d3d08 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -12,10 +12,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -41,49 +42,56 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        if org_model.__class__.__name__ == 'LlamaModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
     llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
     shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
 
-    # check grad
+
     row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
     col_layer_for_check = ['layers[0].self_attn.o_proj']
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-4
         else:
             atol, rtol = 5e-3, 5e-3
-        check_grad(llama_model,
-                   shard_llama_model,
-                   row_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=0,
-                   verbose=False)
-        check_grad(llama_model,
-                   shard_llama_model,
-                   col_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=1,
-                   verbose=False)
-
-    # check weights after optimizer.step()
+        row_layer_grads = get_grad_tensors_for_check(llama_model,
+                                                     shard_llama_model,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0,
+                                                     verbose=False)
+        col_layer_grads = get_grad_tensors_for_check(llama_model,
+                                                     shard_llama_model,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1,
+                                                     verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == 'LlamaModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
@@ -98,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 939b2d555..8c0432b37 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -11,10 +11,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -40,49 +41,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        if org_model.__class__.__name__ == 'OPTModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
     opt_model = unwrap_model(org_model, 'OPTModel', 'model')
     shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
 
-    # check grad
     row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']    # 'decoder.embed_tokens'
     col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
+
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-3
         else:
-            atol, rtol = 3e-2, 3e-2
-        check_grad(opt_model,
-                   shard_opt_model,
-                   row_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=0,
-                   verbose=False)
-        check_grad(opt_model,
-                   shard_opt_model,
-                   col_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=1,
-                   verbose=False)
-
-    # check weights after optimizer.step()
+            atol, rtol = 4e-2, 4e-2
+        row_layer_grads = get_grad_tensors_for_check(opt_model,
+                                                     shard_opt_model,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0,
+                                                     verbose=False)
+        col_layer_grads = get_grad_tensors_for_check(opt_model,
+                                                     shard_opt_model,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1,
+                                                     verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'OPTModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-3, 1e-3
@@ -97,6 +104,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index cd3d3d673..29367031e 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -37,42 +38,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
     t5 = unwrap_model(org_model)
     sharded_t5 = unwrap_model(sharded_model)
 
     row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
 
-    # check grad
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
     if test_config['precision'] == 'fp32':
         atol, rtol = 1e-5, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
-        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
-
-    # check weights after optimizer.step()
+        row_layer_grads = get_grad_tensors_for_check(t5,
+                                                     sharded_t5,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-4, 1e-3
+        atol, rtol = 5e-4, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
         check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index d40058bb7..2980c6eea 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
@@ -36,17 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        if org_model.__class__.__name__ == 'ViTModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
     vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
     shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
@@ -54,31 +44,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grad
     row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
     col_layer_for_check = ['encoder.layer[0].attention.output.dense']
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-5, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_grad(vit_model,
-                   shard_vit_model,
-                   row_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=0,
-                   verbose=False)
-        check_grad(vit_model,
-                   shard_vit_model,
-                   col_layer_for_check,
-                   tp_group,
-                   atol=atol,
-                   rtol=rtol,
-                   dim=1,
-                   verbose=False)
-
-    # check weights after optimizer.step()
+        row_layer_grads = get_grad_tensors_for_check(vit_model,
+                                                     shard_vit_model,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0,
+                                                     verbose=False)
+        col_layer_grads = get_grad_tensors_for_check(vit_model,
+                                                     shard_vit_model,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1,
+                                                     verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == 'ViTModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 5e-3, 1e-3
@@ -93,6 +101,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 356ed6405..a55753018 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -15,10 +15,11 @@ from colossalai.testing import (
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
 )
@@ -41,18 +42,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 2e-4, 2e-4
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        if org_model.__class__.__name__ == 'WhisperModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwarp the model
     if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
         whisper = org_model.model
@@ -75,19 +64,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         #'decoder.layers[0].self_attn.out_proj'
     ]
 
-    # check weights and gradients
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
     if test_config['precision'] == 'fp32':
         atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
-        check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
-
-    # check weights after optimizer.step()
+        row_layer_grads = get_grad_tensors_for_check(whisper,
+                                                     sharded_whisper,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1)
+        col_layer_grads = get_grad_tensors_for_check(whisper,
+                                                     sharded_whisper,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 2e-4, 2e-4
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == 'WhisperModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if test_config['precision'] == 'fp32':
         atol, rtol = 5e-4, 5e-4
     else:
@@ -110,8 +128,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=0,
                      verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
 
 
+#TODO fix WhisperForConditionalGeneration enable jit fused operator
 # TODO(jianghai) fix fp16
 @parameterize(