modified:   internlm/model/moe.py
modified:   internlm/moe/sharded_moe.py
modified:   internlm/utils/parallel.py
pull/375/head
Wenwen Qu 2023-08-08 16:46:14 +08:00
parent 84476833f3
commit 9c8b999291
3 changed files with 15 additions and 9 deletions

internlm/model/moe.py

@@ -81,8 +81,8 @@ class MoE(torch.nn.Module):
         self.num_local_experts = num_experts // self.ep_size
         logger.info(
-            f"""Creating MoE layer with num_experts: {num_experts} | num_local_experts:
-            {self.num_local_experts} | expert_parallel_size: {self.ep_size}"""
+            f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
+            f"{self.num_local_experts} | expert_parallel_size: {self.ep_size}"
         )
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
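Note on the reworked log message above: Python concatenates adjacent (f-)string literals at parse time without inserting any whitespace, so the spacing of the final message comes entirely from the literals themselves. A minimal standalone sketch of that behaviour (the variable values below are made up for illustration):

num_experts = 8
num_local_experts = 4
ep_size = 2

# Adjacent f-string literals are joined with no implicit separator,
# mirroring the style of the updated logger.info call.
msg = (
    f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
    f"{num_local_experts} | expert_parallel_size: {ep_size}"
)
print(msg)
# Creating MoE layer with num_experts: 8 | num_local_experts:4 | expert_parallel_size: 2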

internlm/moe/sharded_moe.py

@@ -226,12 +226,12 @@ def top1gating(
     else:
         mask1_rand = mask1
-    assert (
-        logits.shape[0] >= min_capacity
-    ), """No. of tokens (batch-size) should be greater than min_capacity.
-    Either set min_capacity to 0 or increase your batch size."""
+    assert logits.shape[0] >= min_capacity, (
+        "No. of tokens (batch-size) should be greater than min_capacity."
+        "Either set min_capacity to 0 or increase your batch size."
+    )
-    top_idx = _top_idx(mask1_rand, capacity)  # @wenwen: token index
+    top_idx = _top_idx(mask1_rand, capacity)  # token index
     new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
     mask1 = new_mask1
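For context on the assertion: top1gating then keeps at most `capacity` tokens per expert by taking the top token indices from the (possibly jittered) assignment mask and zeroing out everything else, which is only well defined when the batch holds at least min_capacity tokens. A self-contained sketch of that truncation step, where _top_idx is a hypothetical stand-in assumed to behave like a per-expert torch.topk over the token dimension:

import torch

def _top_idx(source, k):
    # Hypothetical stand-in: indices of the top-k tokens for each expert column.
    return torch.topk(source, k=k, dim=0)[1]

min_capacity = 2
capacity = 2
# One-hot token-to-expert assignment for 6 tokens and 2 experts.
mask1 = torch.tensor(
    [[1, 0], [1, 0], [0, 1], [1, 0], [0, 1], [0, 1]], dtype=torch.float
)
mask1_rand = mask1  # no jitter in this sketch

assert mask1.shape[0] >= min_capacity, "batch size must be at least min_capacity"

top_idx = _top_idx(mask1_rand, capacity)  # token indices kept per expert
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
print(new_mask1.sum(dim=0))  # each expert keeps at most `capacity` tokens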

internlm/utils/parallel.py

@@ -5,6 +5,7 @@ import torch.distributed as dist
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.moe import is_moe_param
 def is_model_parallel_parameter(p):
@@ -20,8 +21,13 @@ def sync_model_param(model, parallel_mode):
     """
     if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
         for param in model.parameters():
-            ranks = gpc.get_ranks_in_group(parallel_mode)
-            dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
+            if is_moe_param(param):
+                # TODO: moe expert param need to sync in expert data parallel group
+                # now we do not support expert data parallel
+                pass
+            else:
+                ranks = gpc.get_ranks_in_group(parallel_mode)
+                dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
 def sync_model_param_within_tp(model):
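The predicate driving the new branch, is_moe_param, is imported from internlm/model/moe.py but its body is not part of this diff. A plausible sketch of the pattern it enables, where the attribute name and helper below are illustrative assumptions rather than the repository's actual API: expert parameters are tagged when the experts are built, and the broadcast loop in sync_model_param skips anything that carries the tag, leaving their synchronization to a future expert-data-parallel group as the TODO notes.

import torch

def tag_moe_parameters(expert_module):
    # Illustrative helper (not from the repo): mark every parameter of an
    # expert module so parameter-sync code can recognize and skip it.
    for p in expert_module.parameters():
        p.is_expert = True

def is_moe_param(p) -> bool:
    # Illustrative stand-in for internlm.model.moe.is_moe_param:
    # a parameter counts as an MoE expert parameter if it carries the tag.
    return getattr(p, "is_expert", False)

expert_fc = torch.nn.Linear(4, 4)  # pretend this lives inside an expert
tag_moe_parameters(expert_fc)
shared_fc = torch.nn.Linear(4, 4)  # ordinary, non-expert module

assert all(is_moe_param(p) for p in expert_fc.parameters())
assert not any(is_moe_param(p) for p in shared_fc.parameters())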