From 9c8b9992913773dadee1601699c6a7c08f24c1f3 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 8 Aug 2023 16:46:14 +0800
Subject: [PATCH] modified: internlm/model/moe.py
 modified: internlm/moe/sharded_moe.py
 modified: internlm/utils/parallel.py

---
 internlm/model/moe.py       |  4 ++--
 internlm/moe/sharded_moe.py | 10 +++++-----
 internlm/utils/parallel.py  | 10 ++++++++--
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index a145b97..180d829 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -81,8 +81,8 @@ class MoE(torch.nn.Module):
 
         self.num_local_experts = num_experts // self.ep_size
         logger.info(
-            f"""Creating MoE layer with num_experts: {num_experts} | num_local_experts:
-            {self.num_local_experts} | expert_parallel_size: {self.ep_size}"""
+            f"Creating MoE layer with num_experts: {num_experts} | num_local_experts: "
+            f"{self.num_local_experts} | expert_parallel_size: {self.ep_size}"
         )
 
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 7af036c..1ae68e0 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -226,12 +226,12 @@ def top1gating(
     else:
         mask1_rand = mask1
 
-    assert (
-        logits.shape[0] >= min_capacity
-    ), """No. of tokens (batch-size) should be greater than min_capacity.
-    Either set min_capacity to 0 or increase your batch size."""
+    assert logits.shape[0] >= min_capacity, (
+        "No. of tokens (batch-size) should be greater than min_capacity. "
+        "Either set min_capacity to 0 or increase your batch size."
+    )
 
-    top_idx = _top_idx(mask1_rand, capacity)  # @wenwen: token index
+    top_idx = _top_idx(mask1_rand, capacity)  # token index
 
     new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
     mask1 = new_mask1
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index cffcdc1..63b190d 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -5,6 +5,7 @@ import torch.distributed as dist
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.moe import is_moe_param
 
 
 def is_model_parallel_parameter(p):
@@ -20,8 +21,13 @@ def sync_model_param(model, parallel_mode):
     """
     if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
         for param in model.parameters():
-            ranks = gpc.get_ranks_in_group(parallel_mode)
-            dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
+            if is_moe_param(param):
+                # TODO: MoE expert params need to be synced in the expert data parallel group;
+                # expert data parallel is not supported yet, so skip the sync for them here.
+                pass
+            else:
+                ranks = gpc.get_ranks_in_group(parallel_mode)
+                dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
 
 
 def sync_model_param_within_tp(model):
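
Note: is_moe_param is imported from internlm.model.moe in internlm/utils/parallel.py above, but its definition is not part of this patch. A minimal sketch of what such a helper could look like, assuming expert parameters are tagged with a boolean attribute when the experts are constructed (the attribute name "is_expert" below is a hypothetical placeholder, not necessarily what internlm/model/moe.py actually uses):

import torch


def is_moe_param(param: torch.nn.Parameter) -> bool:
    # Hypothetical sketch: assumes each expert parameter carries a boolean
    # flag set when the MoE experts are built; the real attribute name in
    # internlm/model/moe.py may differ.
    return getattr(param, "is_expert", False)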