diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index a87e306..4e609bc 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -86,14 +86,6 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: int = 1,
-        moe_gate_k: int = 1,
-        moe_capacity_factor: float = 1.0,
-        moe_eval_capacity_factor: float = 1.0,
-        moe_min_capacity: int = 4,
-        moe_noisy_gate_policy: str = None,
-        moe_drop_tokens: bool = True,
-        moe_use_rts: bool = True,
-        moe_use_residual: bool = False,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -131,14 +123,6 @@ class PackedFlashBaseLayer1D(nn.Module):
             set_fp32_attr_to_module(self.norm2)

         self.num_experts = num_experts
-        self.moe_gate_k = moe_gate_k
-        self.moe_capacity_factor = moe_capacity_factor
-        self.moe_eval_capacity_factor = moe_eval_capacity_factor
-        self.moe_min_capacity = moe_min_capacity
-        self.moe_noisy_gate_policy = moe_noisy_gate_policy
-        self.moe_drop_tokens = moe_drop_tokens
-        self.moe_use_rts = moe_use_rts
-        self.moe_use_residual = moe_use_residual
         ep_size = gpc.get_world_size(ParallelMode.EXPERT)
         if num_experts <= 1:  # dense, not MoE
             if use_swiglu:
@@ -175,14 +159,6 @@ class PackedFlashBaseLayer1D(nn.Module):
                 hidden_size=hidden_size,
                 num_experts=num_experts,
                 ep_size=ep_size,
-                topk=moe_gate_k,
-                capacity_factor=moe_capacity_factor,
-                eval_capacity_factor=moe_eval_capacity_factor,
-                min_capacity=moe_min_capacity,
-                noisy_gate_policy=moe_noisy_gate_policy,
-                drop_tokens=moe_drop_tokens,
-                use_rts=moe_use_rts,
-                use_residual=moe_use_residual,
                 device=device,
                 dtype=dtype,
             )
@@ -357,14 +333,6 @@ class PackedFlashInternLm1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: bool = 1,
-        moe_gate_k: int = 1,
-        moe_capacity_factor: float = 1.0,
-        moe_eval_capacity_factor: float = 1.0,
-        moe_min_capacity: int = 4,
-        moe_noisy_gate_policy: str = None,
-        moe_drop_tokens: bool = True,
-        moe_use_rts: bool = True,
-        moe_use_residual: bool = False,
     ):
         super().__init__()

@@ -415,14 +383,6 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     num_experts=num_experts,
-                    moe_gate_k=moe_gate_k,
-                    moe_capacity_factor=moe_capacity_factor,
-                    moe_eval_capacity_factor=moe_eval_capacity_factor,
-                    moe_min_capacity=moe_min_capacity,
-                    moe_noisy_gate_policy=moe_noisy_gate_policy,
-                    moe_drop_tokens=moe_drop_tokens,
-                    moe_use_rts=moe_use_rts,
-                    moe_use_residual=moe_use_residual,
                 )
                 for lid in range(num_layers)
             ]
@@ -559,14 +519,14 @@ def build_model_with_moe_cfg(
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
     num_experts: int = 1,
-    moe_gate_k: int = 1,
-    moe_capacity_factor: float = 1.0,
-    moe_eval_capacity_factor: float = 1.0,
-    moe_min_capacity: int = 4,
-    moe_noisy_gate_policy: str = None,
-    moe_drop_tokens: bool = True,
-    moe_use_rts: bool = True,
-    moe_use_residual: bool = False,
+    moe_gate_k: int = 1,  # pylint: disable=W0613
+    moe_capacity_factor: float = 1.0,  # pylint: disable=W0613
+    moe_eval_capacity_factor: float = 1.0,  # pylint: disable=W0613
+    moe_min_capacity: int = 4,  # pylint: disable=W0613
+    moe_noisy_gate_policy: str = None,  # pylint: disable=W0613
+    moe_drop_tokens: bool = True,  # pylint: disable=W0613
+    moe_use_rts: bool = True,  # pylint: disable=W0613
+    moe_use_residual: bool = False,  # pylint: disable=W0613
 ):
     """
     Build model with config.
@@ -633,14 +593,6 @@ def build_model_with_moe_cfg(
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
         num_experts=num_experts,
-        moe_gate_k=moe_gate_k,
-        moe_capacity_factor=moe_capacity_factor,
-        moe_eval_capacity_factor=moe_eval_capacity_factor,
-        moe_min_capacity=moe_min_capacity,
-        moe_noisy_gate_policy=moe_noisy_gate_policy,
-        moe_drop_tokens=moe_drop_tokens,
-        moe_use_rts=moe_use_rts,
-        moe_use_residual=moe_use_residual,
     )

     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 4c7b741..ab18f69 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -1,5 +1,3 @@
-import typing
-
 import torch

 from internlm.core.context import ParallelMode
@@ -39,17 +37,9 @@ class MoE(torch.nn.Module):
         hidden_size,
         num_experts=1,
         ep_size=1,
-        topk=1,
-        capacity_factor=1.0,
-        eval_capacity_factor=1.0,
-        min_capacity=4,
-        noisy_gate_policy: typing.Optional[str] = None,
-        drop_tokens: bool = True,
-        use_rts: bool = True,
-        moe_type: str = None,
-        use_residual=False,
         device=None,
         dtype=None,
+        moe_type: str = None,
     ):
         super().__init__()

@@ -61,30 +51,19 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size

-        assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
-            "Unsupported noisy_gate_policy: " + noisy_gate_policy
-        )
-
         if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
                 gpc.get_group(ParallelMode.EXPERT),
                 ep_size,
                 num_experts,
-                topk,
-                capacity_factor,
-                eval_capacity_factor,
-                min_capacity,
-                noisy_gate_policy,
-                drop_tokens,
-                use_rts,
                 device,
                 dtype,
             )

         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
-        self.use_residual = use_residual
-        if use_residual:
+        self.use_residual = getattr(gpc.config.model, "moe_use_residual", False)
+        if self.use_residual:
             self.residual_mlp = FeedForward(
                 hidden_size,
                 int(hidden_size * gpc.config.model.mlp_ratio),
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 1a9fd1c..ed6de15 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -304,7 +304,7 @@ class TopKGate(Module):
         self,
         model_dim: int,
         num_experts: int,
-        k: int = 1,
+        topk: int = 1,
         capacity_factor: float = 1.0,
         eval_capacity_factor: float = 1.0,
         min_capacity: int = 8,
@@ -315,11 +315,11 @@
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
-        if k not in (1, 2):
+        if topk not in (1, 2):
             raise ValueError("Only top-1 and top-2 gatings are supported.")
         # Deepspeed's mechisms, alway use fp32
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
-        self.k = k
+        self.k = topk
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
         self.min_capacity = min_capacity
@@ -388,27 +388,24 @@ class GShardMOELayer(BaseMoELayer):
         ep_group,
         ep_size: int,
         num_experts: int,
-        topk,
-        capacity_factor,
-        eval_capacity_factor,
-        min_capacity,
-        noisy_gate_policy,
-        drop_tokens,
-        use_rts,
         device=None,
         dtype=None,
     ) -> None:
+        noisy_gate_policy = getattr(gpc.config.model, "moe_noisy_gate_policy", None)
+        assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
+            "Unsupported noisy_gate_policy: " + noisy_gate_policy
+        )
         super().__init__(
             TopKGate(
                 hidden_size,
                 num_experts,
-                topk,
-                capacity_factor,
-                eval_capacity_factor,
-                min_capacity,
-                noisy_gate_policy,
-                drop_tokens,
-                use_rts,
+                topk=getattr(gpc.config.model, "moe_gate_k", 1),
+                capacity_factor=getattr(gpc.config.model, "moe_capacity_factor", 1.0),
+                eval_capacity_factor=getattr(gpc.config.model, "moe_eval_capacity_factor", 1.0),
+                min_capacity=getattr(gpc.config.model, "moe_min_capacity", 4),
+                noisy_gate_policy=getattr(gpc.config.model, "moe_noisy_gate_policy", None),
+                drop_tokens=getattr(gpc.config.model, "moe_drop_tokens", True),
+                use_rts=getattr(gpc.config.model, "moe_use_rts", True),
             ),
             torch.nn.ModuleList(
                 [
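
For reference, a minimal sketch of the `model` section of a training config that the new `getattr(gpc.config.model, ...)` lookups rely on. The keys and defaults below are taken from the `getattr` calls in this diff; the surrounding config layout and the chosen example values are assumptions, not part of the patch.

```python
# Hypothetical excerpt of an InternLM training config (Python dict style).
# Each moe_* key is read via getattr(gpc.config.model, key, default);
# omitting a key falls back to the default shown in the diff above.
model = dict(
    num_experts=8,                 # >1 switches the FFN in each layer to MoE
    moe_gate_k=1,                  # top-k used by TopKGate (only 1 or 2 supported)
    moe_capacity_factor=1.0,       # train-time expert capacity factor
    moe_eval_capacity_factor=1.0,  # eval-time expert capacity factor
    moe_min_capacity=4,            # lower bound on tokens routed per expert
    moe_noisy_gate_policy=None,    # one of None, "None", "Jitter", "RSample"
    moe_drop_tokens=True,          # drop tokens that exceed expert capacity
    moe_use_rts=True,              # random token selection when dropping
    moe_use_residual=False,        # residual MLP branch, see arXiv:2201.05596
)
```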
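And a minimal usage sketch of the slimmed-down `MoE` constructor after this change. It assumes `gpc.config` has already been initialized with a config like the one above (as it is during normal training startup); the gating hyperparameters are no longer passed as arguments but read from `gpc.config.model` inside `GShardMOELayer`/`TopKGate`.

```python
# Sketch only: requires an initialized global context (gpc) with expert
# parallel groups and a model config providing mlp_ratio and the moe_* keys.
import torch

from internlm.model.moe import MoE

moe_block = MoE(
    hidden_size=4096,
    num_experts=8,
    ep_size=1,                     # expert-parallel world size
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
# moe_gate_k, moe_capacity_factor, etc. are picked up from gpc.config.model
# at construction time rather than threaded through the call chain.
```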