mirror of https://github.com/InternLM/InternLM

modified: internlm/model/moe.py
modified: internlm/moe/sharded_moe.py
modified: internlm/utils/parallel.py

pull/375/head
parent 84476833f3
commit 9c8b999291
internlm/model/moe.py
@@ -81,8 +81,8 @@ class MoE(torch.nn.Module):
         self.num_local_experts = num_experts // self.ep_size
 
         logger.info(
-            f"""Creating MoE layer with num_experts: {num_experts} | num_local_experts:
-            {self.num_local_experts} | expert_parallel_size: {self.ep_size}"""
+            f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
+            f"{self.num_local_experts} | expert_parallel_size: {self.ep_size}"
         )
 
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
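The moe.py change replaces a triple-quoted f-string, which embeds the line break and continuation indentation into the log message, with two adjacent f-string literals that Python concatenates at compile time. Adjacent literals are joined with no separator, so the new message has no space between "num_local_experts:" and its value. A minimal sketch with illustrative values:

num_experts, num_local_experts, ep_size = 8, 2, 4  # illustrative values only

# Old form: the newline and leading spaces become part of the message.
old_msg = f"""Creating MoE layer with num_experts: {num_experts} | num_local_experts:
            {num_local_experts} | expert_parallel_size: {ep_size}"""

# New form: adjacent f-strings are concatenated verbatim, no separator inserted.
new_msg = (
    f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
    f"{num_local_experts} | expert_parallel_size: {ep_size}"
)

print(repr(old_msg))  # contains '\n' and the indentation spaces
print(repr(new_msg))  # one line; note "num_local_experts:2" with no space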
internlm/moe/sharded_moe.py
@@ -226,12 +226,12 @@ def top1gating(
     else:
         mask1_rand = mask1
 
-    assert (
-        logits.shape[0] >= min_capacity
-    ), """No. of tokens (batch-size) should be greater than min_capacity.
-    Either set min_capacity to 0 or increase your batch size."""
+    assert logits.shape[0] >= min_capacity, (
+        "No. of tokens (batch-size) should be greater than min_capacity."
+        "Either set min_capacity to 0 or increase your batch size."
+    )
 
-    top_idx = _top_idx(mask1_rand, capacity)  # @wenwen: token index
+    top_idx = _top_idx(mask1_rand, capacity)  # token index
 
     new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
     mask1 = new_mask1
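For context on the asserted min_capacity and the `_top_idx` call: top1gating drops tokens that exceed each expert's capacity, and `_top_idx` picks, per expert column, the indices of the tokens kept within that capacity before `scatter_` rebuilds the mask. A small self-contained sketch; the `_top_idx` body below follows the DeepSpeed-style helper this gating code resembles and is an assumption, not copied from sharded_moe.py:

import torch

def _top_idx(source, k):
    # Assumed helper: indices of the k largest entries along the token dim (dim 0).
    return torch.topk(source, k=k, dim=0)[1]

capacity = 2
mask1 = torch.tensor(
    [[1, 0], [1, 0], [0, 1], [1, 0], [0, 1], [0, 1]]
)  # (tokens, experts) one-hot assignment from the top-1 gate
mask1_rand = mask1  # the `else:` branch above: no random tie-breaking

top_idx = _top_idx(mask1_rand, capacity)  # token indices kept for each expert
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
print(new_mask1.sum(dim=0))  # tensor([2, 2]) -> at most `capacity` tokens per expert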
internlm/utils/parallel.py
@@ -5,6 +5,7 @@ import torch.distributed as dist
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.moe import is_moe_param
 
 
 def is_model_parallel_parameter(p):
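The new import lets sync_model_param (second hunk in this file, below) recognize expert parameters and skip them during the data-parallel broadcast. A hypothetical sketch of such a predicate follows; the actual implementation lives in internlm/model/moe.py, and the "is_expert" marker attribute here is an assumption for illustration only:

import torch

def is_moe_param(param: torch.Tensor) -> bool:
    # Hypothetical: expert parameters are assumed to be tagged with a marker
    # attribute when the expert modules are built ("is_expert" is not taken
    # from the repository).
    return getattr(param, "is_expert", False)

w = torch.nn.Parameter(torch.randn(4, 4))
w.is_expert = True         # tag assumed to be set by the expert-building code
print(is_moe_param(w))     # True -> skipped by the broadcast in sync_model_param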
internlm/utils/parallel.py
@@ -20,8 +21,13 @@ def sync_model_param(model, parallel_mode):
     """
     if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
         for param in model.parameters():
-            ranks = gpc.get_ranks_in_group(parallel_mode)
-            dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
+            if is_moe_param(param):
+                # TODO: moe expert param need to sync in expert data parallel group
+                # now we do not support expert data parallel
+                pass
+            else:
+                ranks = gpc.get_ranks_in_group(parallel_mode)
+                dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
 
 
 def sync_model_param_within_tp(model):
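With the branch above, sync_model_param still broadcasts every non-expert parameter from the first rank of the given process group, so data-parallel replicas start from identical weights, while expert parameters are left alone until expert data parallelism is supported. A hedged usage sketch, assuming the distributed context (gpc) has already been initialized by the training entry point:

import torch

from internlm.core.context import ParallelMode
from internlm.utils.parallel import sync_model_param

# Placeholder model; the real training script builds the InternLM model after
# initializing torch.distributed and the global context.
model = torch.nn.Linear(8, 8)

# Broadcast non-expert parameters from rank 0 of the data-parallel group;
# expert parameters are skipped by the new is_moe_param branch.
sync_model_param(model, parallel_mode=ParallelMode.DATA)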