From 4fa6b9509c7dfe44c4e99188a811255a848f8dbf Mon Sep 17 00:00:00 2001
From: botbw
Date: Wed, 18 Sep 2024 10:09:01 +0800
Subject: [PATCH] [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)

---
 colossalai/shardformer/modeling/deepseek.py | 13 +++++++++++++
 .../test_model/test_shard_deepseek.py       | 10 ++++++----
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 4b1b82b7c..7bcdf6fc9 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -109,6 +109,19 @@ class EPDeepseekMoE(ParallelModule):
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
 
+        if self.config.n_shared_experts is not None:
+            self.shared_experts.gate_proj = Linear1D_Col.from_native_module(
+                self.shared_experts.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+            )
+
+            self.shared_experts.up_proj = Linear1D_Col.from_native_module(
+                self.shared_experts.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+            )
+
+            self.shared_experts.down_proj = Linear1D_Row.from_native_module(
+                self.shared_experts.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+            )
+
     @staticmethod
     def from_native_module(
         module,
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index d782a2a09..4b92dbdee 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -20,14 +20,15 @@ from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
 NUM_BATCH = 8
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
 NUM_LAYERS = 4
-HIDDEN_SIZE_PER_HEAD = 4
+HIDDEN_SIZE_PER_HEAD = 8
 NUM_HEADS = 8
 TOP_K = 2
 
 
-def run_deepseek_commom(config: Tuple[int, ...]):
+def run_deepseek_commom(parallel_config: Tuple[int, ...]):
     Randomizer.reset_index()
-    stage, ep_size, pp_size, tp_size, sp_size = config
+    print(f"rank {dist.get_rank()} testing {parallel_config}")
+    stage, ep_size, pp_size, tp_size, sp_size = parallel_config
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     dtype, precision = torch.bfloat16, "bf16"
@@ -65,6 +66,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
         attn_implementation="flash_attention_2",
         torch_dtype="float16",
         n_routed_experts=NUM_EXPERTS,
+        n_shared_experts=2,
         num_experts_per_tok=TOP_K,
         trust_remote_code=True,
     )
@@ -159,7 +161,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
     if rank == world_size - 1:
         shutil.rmtree(model_dir)
 
-    print(f"rank {dist.get_rank()} test passed")
+    print(f"rank {dist.get_rank()} passed {parallel_config}")
 
 
 @parameterize(
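
Note (not part of the patch): the first hunk shards the shared expert's gated MLP with a column-parallel split for gate_proj/up_proj and a row-parallel split for down_proj, the usual pairing that defers communication to a single reduction after down_proj. Below is a minimal single-process sketch of that pattern in plain PyTorch; all names (hidden, inter, tp, w_gate, ...) are illustrative, and the loop over simulated ranks stands in for real tensor-parallel workers, so this is not ColossalAI's Linear1D_Col/Linear1D_Row implementation.

# Minimal single-process sketch (toy names, not ColossalAI's API) of the
# column/row tensor-parallel split used for the shared expert's gated MLP:
# gate_proj/up_proj are split along their output dim, down_proj along its
# input dim, so the partial outputs only need one final sum (the all-reduce).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, tp = 16, 32, 4      # hidden size, intermediate size, TP degree
x = torch.randn(2, hidden)         # a couple of token embeddings

# Unsharded gated-MLP weights, shaped like nn.Linear weights (out_features, in_features).
w_gate = torch.randn(inter, hidden)
w_up = torch.randn(inter, hidden)
w_down = torch.randn(hidden, inter)

# Reference: full (non-parallel) forward pass of the gated MLP.
ref = (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

# Simulated tensor-parallel forward: each "rank" holds a column shard of gate/up
# and the matching row shard of down, computes a partial output, and the partials
# are summed at the end (this sum is what the row-parallel layer's all-reduce does).
shard = inter // tp
partials = []
for r in range(tp):
    cols = slice(r * shard, (r + 1) * shard)
    gate_r, up_r, down_r = w_gate[cols], w_up[cols], w_down[:, cols]
    partials.append((F.silu(x @ gate_r.T) * (x @ up_r.T)) @ down_r.T)

out = torch.stack(partials).sum(dim=0)
assert torch.allclose(out, ref, atol=1e-3)
print("sharded shared-expert MLP matches the unsharded reference")

Pairing a column-parallel split with a row-parallel split keeps the intermediate activations sharded end to end, so the shared expert should add roughly one all-reduce per forward pass on top of the routed experts' expert-parallel communication.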