From 513ebb9c3a38778b2334f7cb386e02ac6beb8b64 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Mon, 18 Dec 2023 14:39:42 +0800 Subject: [PATCH] fix(moe): fix moe zero mode bug (#548) * fix moe zero mode bugs * update moe config to fit training on 8 GPU --- configs/7B_MoE4_sft.py | 2 +- internlm/solver/optimizer/hybrid_zero_optim.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index cc94cdc..0672422 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -141,7 +141,7 @@ model = dict( layer_norm_epsilon=1e-5, use_flash_attn=True, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. - num_experts=8, + num_experts=4, moe_use_residual=False, moe_gate_k=2, ) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index eb7aae3..c4b87d7 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -150,7 +150,7 @@ class HybridZeroOptimizer(BaseOptimizer): # if zero is used, expert dp group will use ParallelMode.EXPERT_DATA as the real zero mode zero_mode = ( ParallelMode.ZERO1 - if param_group["dp_mode"] == gpc.get_world_size(ParallelMode.ZERO1) == 1 or ParallelMode.DATA + if gpc.get_world_size(ParallelMode.ZERO1) == 1 or param_group["dp_mode"] == ParallelMode.DATA else ParallelMode.EXPERT_DATA ) self._zero_local_rank.append(gpc.get_local_rank(zero_mode))