diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 029ac36cd..d30ce5ea8 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -245,7 +245,6 @@ class MixtralPipelineForwards:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            print("input_ids", input_ids.shape)
             batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
             batch_size, seq_length, _ = inputs_embeds.shape
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 187c642da..46da4522f 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -17,7 +17,7 @@ from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
 
 NUM_BATCH = 8
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 4000, 2
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
 NUM_LAYERS = 4
 HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS = 4
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_skip.py b/tests/test_shardformer/test_model/test_shard_deepseek_skip.py
deleted file mode 100644
index fe834a4f6..000000000
--- a/tests/test_shardformer/test_model/test_shard_deepseek_skip.py
+++ /dev/null
@@ -1,232 +0,0 @@
-# modified from test_shard_mistral.py
-import os
-
-import pytest
-import torch
-import torch.distributed as dist
-from torch.testing import assert_close
-
-import colossalai
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import (
-    build_model_from_hybrid_plugin,
-    check_all_grad_tensors,
-    check_loss,
-    check_output_hidden_state,
-    check_weight,
-    get_grad_tensors_for_check,
-    run_forward_backward_with_hybrid_plugin,
-    unwrap_model,
-)
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
-    # TODO: SGD failed for full dp
-    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
-        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
-    )
-
-    org_model = org_model.to(torch.float16)
-    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
-    )
-    stage_manager = booster.plugin.stage_manager
-    tp_group = booster.plugin.tp_group
-
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config["precision"] == "fp32":
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-        check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
-
-    # unwrap model
-    mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
-    shard_mixtral_model = unwrap_model(sharded_model, "DeepseekModel", "model")
-
-    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
-    col_layer_for_check = ["layers[0].self_attn.o_proj"]
-
-    name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
-    # Check the grad when using ZeRO-1 and ZeRO-2
-    if (
-        # booster.plugin.zero_stage in [1, 2]
-        booster.plugin.shard_config.enable_sequence_parallelism
-        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
-    ):
-        rank = dist.get_rank()
-        for n, p in shard_mixtral_model.named_parameters():
-            zero_grad = sharded_optimizer.get_param_grad(p)
-            if name_to_p[n].grad is None:
-                name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
-                continue
-            assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
-
-    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
-    grads_to_check = {}
-    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
-        if test_config["precision"] == "fp32":
-            atol, rtol = 5e-5, 1e-4
-        else:
-            atol, rtol = 5e-3, 5e-3
-        row_layer_grads = get_grad_tensors_for_check(
-            mixtral_model,
-            shard_mixtral_model,
-            row_layer_for_check,
-            tp_group,
-            atol=atol,
-            rtol=rtol,
-            dim=0,
-            verbose=False,
-        )
-        col_layer_grads = get_grad_tensors_for_check(
-            mixtral_model,
-            shard_mixtral_model,
-            col_layer_for_check,
-            tp_group,
-            atol=atol,
-            rtol=rtol,
-            dim=1,
-            verbose=False,
-        )
-        grads_to_check.update(col_layer_grads)
-        grads_to_check.update(row_layer_grads)
-
-    # check grads
-    check_all_grad_tensors(grads_to_check)
-
-    for n, p in shard_mixtral_model.named_parameters():
-        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
-
-    # optimizer executes step
-    org_optimizer.step()
-    sharded_optimizer.step()
-
-    for n, p in shard_mixtral_model.named_parameters():
-        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
-
-    # check weights
-    if stage_manager is None or stage_manager.is_first_stage():
-        if test_config["precision"] == "fp32":
-            atol, rtol = 2e-4, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        try:
-            check_weight(
-                mixtral_model,
-                shard_mixtral_model,
-                col_layer_for_check,
-                tp_group,
-                atol=atol,
-                rtol=rtol,
-                dim=1,
-                verbose=False,
-            )
-        except Exception as e:
-            rank = dist.get_rank()
-            print(f"{rank=}, Failed config: {test_config}")
-            raise e
-
-    torch.cuda.empty_cache()
-
-
-@parameterize(
-    "test_config",
-    [
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "num_microbatches": 2,
-        #     "ep_size": 2,
-        #     "zero_stage": 0,
-        #     "overlap_communication": False,
-        #     "precision": "fp16",
-        # },  # [dp(4)] + [moe_dp(4)]
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 2,
-        #     "num_microbatches": 2,
-        #     "ep_size": 2,
-        #     "zero_stage": 1,
-        #     "overlap_communication": False,
-        #     "precision": "fp32",
-        # },  # [dp(2) + pp(2)] + [moe_pp(2)]
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 2,
-        #     "ep_size": 2,
-        #     "num_microbatches": 2,
-        #     "zero_stage": 1,
-        #     "overlap_communication": False,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        #     "find_unused_parameters": True,
-        # },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
-        {  # Ulysess + Flash attention
-            "tp_size": 1,
-            "pp_size": 1,
-            "sp_size": 2,
-            "ep_size": 2,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "zero_stage": 1,
-            "overlap_communication": False,
-            "precision": "fp16",
-            "initial_scale": 1,
-            "find_unused_parameters": True,
-        },
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "ep_size": 2,
-        #     "zero_stage": 0,
-        #     "overlap_communication": False,
-        #     "precision": "fp32",
-        # },  # [dp(4)] + [ep(2) + moe_tp(2)]
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "ep_size": 4,
-        #     "overlap_communication": False,
-        #     "zero_stage": 0,
-        #     "precision": "fp32"
-        # },  # full dp for non-moe and full ep for moe
-    ],
-)
-def run_deepseek_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek")
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-
-    clear_layout_converter()
-    Randomizer.reset_index()
-    torch.cuda.empty_cache()
-
-
-def check_deepseek(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    run_deepseek_test()
-
-
-@pytest.mark.skip("redundant")
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_mixtral():
-    spawn(check_deepseek, 4)
-
-
-if __name__ == "__main__":
-    test_mixtral()