From 7077d38d5a5b9243521f44d10d4dabc012044dbb Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 18 Jul 2024 13:36:18 +0000
Subject: [PATCH] [moe] finalize test (no pp)

---
 .../plugin/moe_hybrid_parallel_plugin.py | 18 ++++++++-----
 tests/test_moe/modelling/test_mixtral.py | 27 ++++++++++++-------
 2 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index d4226b108..31b346b10 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -109,6 +109,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         super().__init__(*args, **kwargs)
 
+        if ep_size <= 1:
+            raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
+
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
 
@@ -128,12 +131,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.ddp_config["find_unused_parameters"] = True
 
             if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                # TODO it might make sense to support non-moe with tp on but moe with tp off
                 raise ValueError(
-                    f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
+                    f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
                 )
 
-        # set ep_group after super().__init__()
-        # TODO do it in a better way
+        # set param group in shard config
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
@@ -149,9 +152,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             # when sequence parallelism is enabled, ep_group reuses sp_group
             if self.ep_size != self.sp_size:
                 raise ValueError(
-                    f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+                    f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
                 )
 
+            # since we are reusing sp_group, moe_dp_group will be derived as dp_group
             self.moe_dp_size = self.dp_size
             self.moe_dp_group = self.dp_group  # NOTE: sequence of value assignment matters
             self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -165,7 +169,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         else:
             self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
 
-            if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+            if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
                 raise ValueError(
                     f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
                 )
@@ -214,8 +218,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 self.moe_tp_group = group
 
         if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
-            # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
-            # this assertion implies that dp_size == moe_dp_size * ep_size
+            # NOTE: different tp settings between moe and non moe param are complex to handle
+            # we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size
             raise NotImplementedError(
                 f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
             )
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index fe13b5b30..69d9fa5d4 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -18,28 +18,34 @@ from tests.test_moe.moe_utils import loose_close
 from tests.test_moe.test_moe_checkpoint import check_model_equal
 
 NUM_BATCH = 4
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
 HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
-    stage, ep_size, tp_size = config
-    dtype, precision = torch.float16, "fp16"
+    ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
+    print(config)
     rank = torch.distributed.get_rank()
+    dtype, precision = torch.float16, "fp16"
     torch.cuda.set_device(dist.get_rank())
 
     plugin = MoeHybridParallelPlugin(
-        pp_size=1,
+        pp_size=pp_size,
+        num_microbatches=pp_size,
         tp_size=tp_size,
-        moe_tp_size=tp_size,
+        sp_size=sp_size,
         ep_size=ep_size,
+        moe_tp_size=tp_size,
         zero_stage=stage,
+        enable_sequence_parallelism=sp_size > 1,
+        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
+        find_unused_parameters=True,
     )
     booster = Booster(plugin=plugin)
 
@@ -53,6 +59,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         num_key_value_heads=NUM_HEADS,
         num_local_experts=NUM_EXPERTS,
         num_experts_per_tok=TOP_K,
+        attn_implementation="flash_attention_2",
     )
 
     torch_model = MixtralModel(config).to(dtype).cuda()
@@ -72,7 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         input_data = torch.rand(
             NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
         ).cuda()
-        dist.all_reduce(input_data, group=plugin.tp_group)  # tp requires duplicate input
+
+        dist.all_reduce(input_data, group=plugin.tp_group)  # tp group requires duplicate input
+        dist.all_reduce(input_data, group=plugin.sp_group)  # sp group requires duplicate input
 
         zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
         zero_optimizer.backward(zero_output)
@@ -124,11 +133,11 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("world_size", [8])
 @rerun_if_address_is_in_use()
 def test_mistral(world_size):
     spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_mistral(world_size=4)
+    test_mistral(world_size=8)
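
Note (not part of the patch): the checks added in moe_hybrid_parallel_plugin.py mean the plugin now requires ep_size > 1, moe_tp_size equal to tp_size, and zero_stage > 0 whenever dp_group and moe_dp_group would otherwise diverge under DDP. Below is a minimal illustrative sketch of a configuration that satisfies these constraints, using only constructor keywords that appear in the test above; the concrete sizes are assumptions for an 8-rank run and the distributed environment is assumed to be initialized already (e.g. via colossalai.launch).

    # Illustrative sketch only; assumes torch.distributed is already initialized on 8 ranks.
    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    plugin = MoeHybridParallelPlugin(
        pp_size=1,
        tp_size=2,
        ep_size=2,          # must be > 1, otherwise the new ValueError points to HybridParallelPlugin
        moe_tp_size=2,      # must equal tp_size, otherwise NotImplementedError is raised
        zero_stage=1,       # avoids the DDP path, where dp_group and moe_dp_group must match
        precision="fp16",
        initial_scale=1,
        overlap_communication=False,
    )
    # With 8 ranks: moe_dp_size = 8 // (pp_size * ep_size * moe_tp_size) = 2,
    # and dp_size = moe_dp_size * ep_size = 4, so the shared-tp-group assumption holds.
    booster = Booster(plugin=plugin)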