mirror of https://github.com/InternLM/InternLM
modified: internlm/model/moe.py
modified: internlm/moe/sharded_moe.py
modified: internlm/utils/parallel.py

pull/375/head
parent 84476833f3
commit 9c8b999291
internlm/model/moe.py
@@ -81,8 +81,8 @@ class MoE(torch.nn.Module):
         self.num_local_experts = num_experts // self.ep_size
 
         logger.info(
-            f"""Creating MoE layer with num_experts: {num_experts} | num_local_experts:
-            {self.num_local_experts} | expert_parallel_size: {self.ep_size}"""
+            f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
+            f"{self.num_local_experts} | expert_parallel_size: {self.ep_size}"
         )
 
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
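Note on the moe.py hunk: the triple-quoted f-string embedded a newline and indentation into the log message, so it is replaced by two adjacent f-strings, which Python concatenates into a single literal at compile time. A minimal standalone sketch of the resulting message; the values below are illustrative, not module state:

# Minimal sketch of implicit f-string concatenation, matching the new logger call.
# num_experts, num_local_experts and ep_size are illustrative values (assumption).
num_experts, num_local_experts, ep_size = 8, 2, 4
message = (
    f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
    f"{num_local_experts} | expert_parallel_size: {ep_size}"
)
print(message)
# Creating MoE layer with num_experts: 8 | num_local_experts:2 | expert_parallel_size: 4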
internlm/moe/sharded_moe.py
@@ -226,12 +226,12 @@ def top1gating(
     else:
         mask1_rand = mask1
 
-    assert (
-        logits.shape[0] >= min_capacity
-    ), """No. of tokens (batch-size) should be greater than min_capacity.
-    Either set min_capacity to 0 or increase your batch size."""
+    assert logits.shape[0] >= min_capacity, (
+        "No. of tokens (batch-size) should be greater than min_capacity."
+        "Either set min_capacity to 0 or increase your batch size."
+    )
 
-    top_idx = _top_idx(mask1_rand, capacity)  # @wenwen: token index
+    top_idx = _top_idx(mask1_rand, capacity)  # token index
 
     new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
     mask1 = new_mask1
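Note on the sharded_moe.py hunk: the assert is restructured so the condition stays on one line and the message is built from adjacent string literals, and the stale @wenwen tag is dropped from the comment. The surrounding code caps how many tokens each expert keeps; below is a standalone sketch of that capacity-truncation step under assumed shapes, where _top_idx_sketch is a stand-in inferred from how _top_idx is used here, not the module's implementation:

import torch

# Sketch of the capacity-truncation step around the changed lines; shapes and the
# _top_idx stand-in are assumptions, not code from internlm/moe/sharded_moe.py.
def _top_idx_sketch(source, k):
    # indices of the k largest entries along dim 0 (the token dimension), per expert
    return torch.topk(source, k=k, dim=0)[1]

tokens, experts, capacity = 6, 2, 2
mask1 = torch.randint(0, 2, (tokens, experts))            # token-to-expert assignment mask
mask1_rand = mask1                                         # branch where no jitter is applied

top_idx = _top_idx_sketch(mask1_rand.float(), capacity)    # token indices kept per expert
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
mask1 = new_mask1                                          # at most `capacity` tokens per expert survive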
internlm/utils/parallel.py
@@ -5,6 +5,7 @@ import torch.distributed as dist
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.moe import is_moe_param
 
 
 def is_model_parallel_parameter(p):
@@ -20,8 +21,13 @@ def sync_model_param(model, parallel_mode):
     """
     if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
        for param in model.parameters():
-            ranks = gpc.get_ranks_in_group(parallel_mode)
-            dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
+            if is_moe_param(param):
+                # TODO: moe expert param need to sync in expert data parallel group
+                # now we do not support expert data parallel
+                pass
+            else:
+                ranks = gpc.get_ranks_in_group(parallel_mode)
+                dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
 
 
 def sync_model_param_within_tp(model):
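Taken together, the parallel.py hunks make sync_model_param skip MoE expert parameters instead of broadcasting them, since an expert data parallel group is not supported yet. Below is a standalone sketch of that branching with the distributed calls stubbed out; the flag-based is_moe_param check is an assumption for illustration, not InternLM's actual implementation:

import torch
from torch import nn

def is_moe_param_sketch(param: torch.Tensor) -> bool:
    # Hypothetical flag-based check; the real is_moe_param in internlm.model.moe may differ.
    return getattr(param, "is_expert", False)

def partition_params_for_sync(model: nn.Module):
    """Split parameters into those to broadcast from rank 0 and experts left untouched."""
    to_broadcast, skipped_experts = [], []
    for param in model.parameters():
        if is_moe_param_sketch(param):
            # expert params would need an expert-data-parallel group; skipped for now
            skipped_experts.append(param)
        else:
            to_broadcast.append(param)  # dist.broadcast(param, ...) would go here
    return to_broadcast, skipped_experts

model = nn.Linear(4, 4)
model.weight.is_expert = True  # mark the weight as an MoE expert parameter for the demo
dense, experts = partition_params_for_sync(model)
print(len(dense), len(experts))  # 1 1 -> the bias would be broadcast, the weight is skipped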