diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py
index 0672422..6abb17b 100644
--- a/configs/7B_MoE4_sft.py
+++ b/configs/7B_MoE4_sft.py
@@ -143,7 +143,7 @@ model = dict(
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
     num_experts=4,
     moe_use_residual=False,
-    moe_gate_k=2,
+    moe_type="GShard",
 )

 # zero1 parallel:
@@ -176,6 +176,17 @@ monitor = dict(
     ),
 )

+# custom moe impl configs
+moe = dict(
+    top_k=2,
+    capacity_factor=1.0,
+    eval_capacity_factor=1.0,
+    min_capacity=4,
+    noisy_gate_policy=None,
+    drop_tokens=True,
+    use_rts=True,
+)
+
 model_type = "INTERNLM_MoE"

 # metric_dtype can be "fp32" or other string
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 7d6badc..b0463a3 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -305,8 +305,8 @@ def args_sanity_check():
         model._add_item("num_experts", 1)
     if "moe_use_residual" not in model:
         model._add_item("moe_use_residual", False)
-    if "moe_gate_k" not in model:
-        model._add_item("moe_gate_k", 2)
+    if "moe_type" not in model:
+        model._add_item("moe_type", None)
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index ff6197d..e42ba25 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -53,16 +53,9 @@ class PackedFlashBaseLayer1D(nn.Module):
         norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
         use_flash_attn (bool): Whether use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
-        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
-        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
-        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
-        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens
-            (setting to False is equivalent to infinite capacity).
-        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        moe_type (str, optional): default=None, the MoE implementation to use; None selects GShardMoE.
     """

     def __init__(
@@ -158,6 +151,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             self.mlp = MoE(
                 hidden_size=hidden_size,
                 num_experts=num_experts,
+                ep_group=gpc.get_group(ParallelMode.EXPERT),
                 ep_size=ep_size,
                 device=device,
                 dtype=dtype,
@@ -292,16 +286,9 @@ class PackedFlashInternLm1D(nn.Module):
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
-        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
-        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
-        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
-        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens
-            (setting to False is equivalent to infinite capacity).
-        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        moe_type (str, optional): default=None, the MoE implementation to use; None selects GShardMoE.
     """

     def __init__(
@@ -519,13 +506,6 @@ def build_model_with_moe_cfg(
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
     num_experts: int = 1,
-    moe_gate_k: int = 1,  # pylint: disable=W0613
-    moe_capacity_factor: float = 1.0,  # pylint: disable=W0613
-    moe_eval_capacity_factor: float = 1.0,  # pylint: disable=W0613
-    moe_min_capacity: int = 4,  # pylint: disable=W0613
-    moe_noisy_gate_policy: str = None,  # pylint: disable=W0613
-    moe_drop_tokens: bool = True,  # pylint: disable=W0613
-    moe_use_rts: bool = True,  # pylint: disable=W0613
     moe_use_residual: bool = False,  # pylint: disable=W0613
     moe_type: str = None,  # pylint: disable=W0613
 ):
@@ -559,16 +539,9 @@ def build_model_with_moe_cfg(
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
-        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
-        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
-        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
-        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens
-            (setting to False is equivalent to infinite capacity).
-        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        moe_type (str, optional): default=None, the MoE implementation to use; None selects GShardMoE.
     """

     cfg = dict(
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index ff37a6d..e4902c8 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -36,6 +36,7 @@ class MoE(torch.nn.Module):
         self,
         hidden_size,
         num_experts=1,
+        ep_group=None,
         ep_size=1,
         device=None,
         dtype=None,
     ):
@@ -43,27 +44,23 @@
         super().__init__()

-        assert (
-            num_experts % ep_size == 0
-        ), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
-        self.ep_size = ep_size
-        self.num_experts = num_experts
-        self.num_local_experts = num_experts // self.ep_size
+        moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))

-        moe_type = getattr(gpc.config.model, "moe_type", None)
+        if not hasattr(gpc.config, "moe"):
+            gpc.config.moe = dict()

-        if moe_type is None or moe_type == "GShard":
-            self.moe_layer = GShardMOELayer(
-                hidden_size,
-                gpc.get_group(ParallelMode.EXPERT),
-                ep_size,
-                num_experts,
-                device,
-                dtype,
-            )
+        self.moe_layer = moe_impl(
+            hidden_size=hidden_size,
+            num_experts=num_experts,
+            ep_group=ep_group,
+            ep_size=ep_size,
+            device=device,
+            dtype=dtype,
+            **(gpc.config.moe)
+        )

         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
-        self.use_residual = getattr(gpc.config.model, "moe_use_residual", False)
+        self.use_residual = gpc.config.model.moe_use_residual
         if self.use_residual:
             self.residual_mlp = FeedForward(
                 hidden_size,
@@ -77,6 +74,10 @@
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)

+    def get_moe(self, moe_type):
+        if moe_type is None or moe_type == "GShard":
+            return GShardMOELayer
+
     def forward(self, hidden_states, used_token=None):
         """MoE forward

diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index ed6de15..b71c1ee 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -385,27 +385,36 @@ class GShardMOELayer(BaseMoELayer):
     def __init__(
         self,
         hidden_size,
+        num_experts: int,
         ep_group,
         ep_size: int,
-        num_experts: int,
+        top_k: int = 1,
+        capacity_factor: float = 1.0,
+        eval_capacity_factor: float = 1.0,
+        min_capacity: int = 4,
+        noisy_gate_policy: str = None,
+        drop_tokens: bool = True,
+        use_rts: bool = True,
         device=None,
         dtype=None,
     ) -> None:
-        noisy_gate_policy = getattr(gpc.config.model, "noisy_gate_policy", None)
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
             "Unsupported noisy_gate_policy: " + noisy_gate_policy
         )
+        assert (
+            num_experts % ep_size == 0
+        ), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"

         super().__init__(
             TopKGate(
                 hidden_size,
                 num_experts,
-                topk=getattr(gpc.config.model, "moe_gate_k", 1),
-                capacity_factor=getattr(gpc.config.model, "moe_capacity_factor", 1.0),
-                eval_capacity_factor=getattr(gpc.config.model, "moe_eval_capacity_factor", 1.0),
-                min_capacity=getattr(gpc.config.model, "moe_min_capacity", 4),
-                noisy_gate_policy=getattr(gpc.config.model, "moe_noisy_gate_policy", None),
-                drop_tokens=getattr(gpc.config.model, "moe_drop_tokens", True),
-                use_rts=getattr(gpc.config.model, "moe_use_rts", True),
+                top_k,
+                capacity_factor,
+                eval_capacity_factor,
+                min_capacity,
+                noisy_gate_policy,
+                drop_tokens,
+                use_rts,
             ),
             torch.nn.ModuleList(
                 [
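Note (not part of the patch): after this change the gating hyperparameters move out of the model dict into a top-level moe dict, and model.moe_type selects the implementation. A minimal sketch of the resulting config layout, with values copied from configs/7B_MoE4_sft.py above and all other model fields omitted:

# Sketch of the new MoE config layout (values from configs/7B_MoE4_sft.py).
model = dict(
    num_experts=4,
    moe_use_residual=False,
    moe_type="GShard",  # None also resolves to the GShard implementation
)

# This dict is forwarded into the selected MoE layer via **gpc.config.moe,
# so every key must match a keyword parameter of that layer (e.g. top_k,
# capacity_factor, ... for GShardMOELayer); an unexpected key raises TypeError.
moe = dict(
    top_k=2,
    capacity_factor=1.0,
    eval_capacity_factor=1.0,
    min_capacity=4,
    noisy_gate_policy=None,  # or "Jitter" / "RSample"
    drop_tokens=True,
    use_rts=True,
)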
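One more observation: MoE.get_moe as added here falls through and returns None for an unrecognized moe_type, which only surfaces later as a "'NoneType' object is not callable" error. A purely illustrative sketch of how the lookup could fail fast instead; the registry dict and the ValueError are assumptions and not part of this patch, and the import path simply follows the file layout shown above:

# Hypothetical hardening of the moe_type -> implementation lookup.
# GShardMOELayer is the only implementation touched by this patch.
from internlm.moe.sharded_moe import GShardMOELayer

_MOE_IMPLS = {None: GShardMOELayer, "GShard": GShardMOELayer}

def get_moe(moe_type):
    try:
        return _MOE_IMPLS[moe_type]
    except KeyError as exc:
        raise ValueError(f"Unsupported moe_type: {moe_type!r}") from exc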