From c3854f924acde7506b6c9df8a15809cbd4c466d1 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Mon, 8 Jan 2024 14:33:19 +0800
Subject: [PATCH] refactor code

---
 internlm/model/modeling_moe.py |  2 +-
 internlm/model/moe.py          | 51 ++++++++++------------------------
 internlm/moe/base_moe.py       | 17 +++++++++---
 internlm/moe/sharded_moe.py    | 51 ++++++++++++++++++++++++++++++++--
 4 files changed, 77 insertions(+), 44 deletions(-)

diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index df6c7a8..a87e306 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -175,7 +175,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 hidden_size=hidden_size,
                 num_experts=num_experts,
                 ep_size=ep_size,
-                k=moe_gate_k,
+                topk=moe_gate_k,
                 capacity_factor=moe_capacity_factor,
                 eval_capacity_factor=moe_eval_capacity_factor,
                 min_capacity=moe_min_capacity,
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index dc63de0..63dcb63 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -5,8 +5,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe.experts import Experts
-from internlm.moe.sharded_moe import GShardMOELayer, TopKGate
+from internlm.moe.sharded_moe import GShardMOELayer
 from internlm.utils.logger import get_logger
 
 # global llm logger
@@ -40,7 +39,7 @@ class MoE(torch.nn.Module):
         hidden_size,
         num_experts=1,
         ep_size=1,
-        k=1,
+        topk=1,
         capacity_factor=1.0,
         eval_capacity_factor=1.0,
         min_capacity=4,
@@ -66,43 +65,21 @@ class MoE(torch.nn.Module):
                 "Unsupported noisy_gate_policy: " + noisy_gate_policy
             )
 
-        # for elastic expert paralle, experts may have multiple groups
-        expert_group_name = f"moe_ep_size_{self.ep_size}"
-        if expert_group_name not in gpc.expert_parallel_group_names:
-            gpc.expert_parallel_group_names.append(expert_group_name)
-        experts = torch.nn.ModuleList(
-            [
-                FeedForward(
-                    hidden_size,
-                    int(hidden_size * gpc.config.model.mlp_ratio),
-                    out_features=hidden_size,
-                    process_group=gpc.get_group(ParallelMode.TENSOR),
-                    bias=False,
-                    device=device,
-                    dtype=dtype,
-                )
-                for _ in range(self.num_local_experts)
-            ]
-        )
-        experts = Experts(experts, self.num_local_experts, expert_group_name)
-
         if using_default_moe:
             self.moe_layer = GShardMOELayer(
-                TopKGate(
-                    hidden_size,
-                    num_experts,
-                    k,
-                    capacity_factor,
-                    eval_capacity_factor,
-                    min_capacity,
-                    noisy_gate_policy,
-                    drop_tokens,
-                    use_rts,
-                ),
-                experts,
+                hidden_size,
                 gpc.get_group(ParallelMode.EXPERT),
-                self.ep_size,
-                self.num_local_experts,
+                ep_size,
+                num_experts,
+                topk,
+                capacity_factor,
+                eval_capacity_factor,
+                min_capacity,
+                noisy_gate_policy,
+                drop_tokens,
+                use_rts,
+                device,
+                dtype,
             )
 
         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
diff --git a/internlm/moe/base_moe.py b/internlm/moe/base_moe.py
index 48f01e9..459771f 100644
--- a/internlm/moe/base_moe.py
+++ b/internlm/moe/base_moe.py
@@ -1,7 +1,10 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, ModuleList
+
+from internlm.core.context import global_context as gpc
+from internlm.moe.experts import Experts
 
 if TYPE_CHECKING:
     Base = Module[Tensor]
@@ -14,10 +17,16 @@ class BaseMoELayer(Base):
     Base MoE Layer.
     """
 
-    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
+    def __init__(
+        self, gate: Module, experts: Union[Module, ModuleList], ep_group, ep_size: int, num_local_experts: int
+    ) -> None:
         super().__init__()
+        # for elastic expert parallelism, experts may have multiple groups
+        expert_group_name = f"moe_ep_size_{ep_size}"
+        if expert_group_name not in gpc.expert_parallel_group_names:
+            gpc.expert_parallel_group_names.append(expert_group_name)
         self.gate = gate
-        self.experts = experts
+        self.experts = Experts(experts, num_local_experts, expert_group_name)
         self.ep_group = ep_group
         self.ep_size = ep_size
         self.num_local_experts = num_local_experts
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 631a9da..1a9fd1c 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -12,6 +12,9 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
 
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.model.linear import FeedForward
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
@@ -379,8 +382,52 @@ class GShardMOELayer(BaseMoELayer):
         expert network
     """
 
-    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
-        super().__init__(gate, experts, ep_group, ep_size, num_local_experts)
+    def __init__(
+        self,
+        hidden_size,
+        ep_group,
+        ep_size: int,
+        num_experts: int,
+        topk,
+        capacity_factor,
+        eval_capacity_factor,
+        min_capacity,
+        noisy_gate_policy,
+        drop_tokens,
+        use_rts,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__(
+            TopKGate(
+                hidden_size,
+                num_experts,
+                topk,
+                capacity_factor,
+                eval_capacity_factor,
+                min_capacity,
+                noisy_gate_policy,
+                drop_tokens,
+                use_rts,
+            ),
+            torch.nn.ModuleList(
+                [
+                    FeedForward(
+                        hidden_size,
+                        int(hidden_size * gpc.config.model.mlp_ratio),
+                        out_features=hidden_size,
+                        process_group=gpc.get_group(ParallelMode.TENSOR),
+                        bias=False,
+                        device=device,
+                        dtype=dtype,
+                    )
+                    for _ in range(num_experts // ep_size)
+                ]
+            ),
+            ep_group,
+            ep_size,
+            num_experts // ep_size,
+        )
 
         self.time_falltoall = 0.0
         self.time_salltoall = 0.0