diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index ee434f3..09a10c8 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -52,6 +52,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
         use_flash_attn (bool): Whether use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
+        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter' and 'RSample'.
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens (setting this to False is equivalent to infinite capacity).
+        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE layer (https://arxiv.org/abs/2201.05596).
     """
 
     def __init__(
@@ -73,6 +81,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: int = 1,
+        moe_gate_k: int = 1,
+        moe_capacity_factor: float = 1.0,
+        moe_eval_capacity_factor: float = 1.0,
+        moe_min_capacity: int = 4,
+        moe_noisy_gate_policy: str = None,
+        moe_drop_tokens: bool = True,
+        moe_use_rts: bool = True,
+        moe_use_residual: bool = False,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -107,6 +123,14 @@ class PackedFlashBaseLayer1D(nn.Module):
 
         # TODO: replace num_experts and epsize with function parameter
         self.num_experts = num_experts
+        self.moe_gate_k = moe_gate_k
+        self.moe_capacity_factor = moe_capacity_factor
+        self.moe_eval_capacity_factor = moe_eval_capacity_factor
+        self.moe_min_capacity = moe_min_capacity
+        self.moe_noisy_gate_policy = moe_noisy_gate_policy
+        self.moe_drop_tokens = moe_drop_tokens
+        self.moe_use_rts = moe_use_rts
+        self.moe_use_residual = moe_use_residual
         ep_size = gpc.get_world_size(ParallelMode.EXPERT)
         if num_experts <= 1:  # dense, not MoE
             if use_swiglu:
@@ -135,7 +159,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                     dtype=dtype,
                 )
         else:
-            expert = torch.nn.ModuleList(
+            experts = torch.nn.ModuleList(
                 [
                     FeedForward(
                         hidden_size,
@@ -149,9 +173,34 @@ class PackedFlashBaseLayer1D(nn.Module):
                     for i in range(num_experts // ep_size)
                 ]
             )
-            # TODO: test moe for now, need more parameter such as: capacity_factor,
-            # eval_capacity_factor, min_capacity, drop_tokens
-            self.mlp = MoE(hidden_size=hidden_size, expert=expert, ep_size=ep_size, num_experts=num_experts, k=1)
+
+            if moe_use_residual:
+                residual_mlp = FeedForward(
+                    hidden_size,
+                    int(hidden_size * gpc.config.model.mlp_ratio),
+                    out_features=hidden_size,
+                    process_group=gpc.get_group(ParallelMode.TENSOR),
+                    bias=False,
+                    device=torch.device("cuda"),
+                    dtype=torch.float,
+                )
+
+            self.mlp = MoE(
+                hidden_size=hidden_size,
+                experts=experts,
+                num_experts=num_experts,
+                ep_size=ep_size,
+                k=moe_gate_k,
+                capacity_factor=moe_capacity_factor,
+                eval_capacity_factor=moe_eval_capacity_factor,
+                min_capacity=moe_min_capacity,
+                noisy_gate_policy=moe_noisy_gate_policy,
+                drop_tokens=moe_drop_tokens,
+                use_rts=moe_use_rts,
+                use_residual=moe_use_residual,
+                residual_mlp=residual_mlp if moe_use_residual else None,
+            )
+
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
@@ -278,7 +327,14 @@ class PackedFlashInternLm1D(nn.Module):
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-
+        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter' and 'RSample'.
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens (setting this to False is equivalent to infinite capacity).
+        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE layer (https://arxiv.org/abs/2201.05596).
     """
 
     def __init__(
@@ -309,6 +365,14 @@ class PackedFlashInternLm1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: bool = 1,
+        moe_gate_k: int = 1,
+        moe_capacity_factor: float = 1.0,
+        moe_eval_capacity_factor: float = 1.0,
+        moe_min_capacity: int = 4,
+        moe_noisy_gate_policy: str = None,
+        moe_drop_tokens: bool = True,
+        moe_use_rts: bool = True,
+        moe_use_residual: bool = False,
     ):
         super().__init__()
 
@@ -361,6 +425,14 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     num_experts=num_experts,
+                    moe_gate_k=moe_gate_k,
+                    moe_capacity_factor=moe_capacity_factor,
+                    moe_eval_capacity_factor=moe_eval_capacity_factor,
+                    moe_min_capacity=moe_min_capacity,
+                    moe_noisy_gate_policy=moe_noisy_gate_policy,
+                    moe_drop_tokens=moe_drop_tokens,
+                    moe_use_rts=moe_use_rts,
+                    moe_use_residual=moe_use_residual,
                 )
                 for lid in range(num_layers)
             ]
@@ -499,6 +571,14 @@ def build_model_with_cfg(
     use_flash_attn: bool = True,
     sequence_parallel: bool = False,  # pylint: disable=W0613
     num_experts: int = 1,
+    moe_gate_k: int = 1,
+    moe_capacity_factor: float = 1.0,
+    moe_eval_capacity_factor: float = 1.0,
+    moe_min_capacity: int = 4,
+    moe_noisy_gate_policy: str = None,
+    moe_drop_tokens: bool = True,
+    moe_use_rts: bool = True,
+    moe_use_residual: bool = False,
 ):
     """
     Builde model with config
@@ -530,7 +610,14 @@ def build_model_with_cfg(
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-
+        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter' and 'RSample'.
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens (setting this to False is equivalent to infinite capacity).
+        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE layer (https://arxiv.org/abs/2201.05596).
     """
 
     cfg = dict(
@@ -555,6 +642,14 @@ def build_model_with_cfg(
         use_flash_attn=use_flash_attn,
         sequence_parallel=sequence_parallel,
         num_experts=num_experts,
+        moe_gate_k=moe_gate_k,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_eval_capacity_factor=moe_eval_capacity_factor,
+        moe_min_capacity=moe_min_capacity,
+        moe_noisy_gate_policy=moe_noisy_gate_policy,
+        moe_drop_tokens=moe_drop_tokens,
+        moe_use_rts=moe_use_rts,
+        moe_use_residual=moe_use_residual,
     )
 
     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 180d829..4c2722c 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -58,7 +58,7 @@ class MoE(torch.nn.Module):
     def __init__(
         self,
         hidden_size,
-        expert,
+        experts,
         num_experts=1,
         ep_size=1,
         k=1,
@@ -69,6 +69,8 @@ class MoE(torch.nn.Module):
         drop_tokens: bool = True,
         use_rts: bool = True,
         using_default_moe: bool = True,
+        use_residual: bool = False,
+        residual_mlp=None,
     ):
         super().__init__()
 
@@ -89,7 +91,7 @@ class MoE(torch.nn.Module):
                 "Unsupported noisy_gate_policy: " + noisy_gate_policy
             )
 
-        experts = Experts(expert, self.num_local_experts)
+        experts = Experts(experts, self.num_local_experts)
 
         if using_default_moe:
             self.moe_layer = MOELayer(
@@ -110,6 +112,12 @@ class MoE(torch.nn.Module):
                 self.num_local_experts,
             )
 
+        self.use_residual = use_residual
+        if use_residual:
+            self.residual_mlp = residual_mlp
+            # coefficient produces the weights for the weighted sum of the expert output and the residual MLP output
+            self.coefficient = torch.nn.Linear(hidden_size, 2)
+
     def forward(self, hidden_states, used_token=None):
         """MoE forward
 
@@ -127,5 +135,12 @@ class MoE(torch.nn.Module):
             * exp_counts (int): expert count
         """
         output = self.moe_layer(hidden_states, used_token)
-
+        if self.use_residual:
+            # Residual MoE: blend the routed expert output with a dense MLP branch
+            output_mlp = self.residual_mlp(hidden_states)
+            if isinstance(output_mlp, tuple):
+                output_mlp = output_mlp[0]  # Ignore the bias term for now
+            coef = self.coefficient(hidden_states)
+            coef = torch.nn.functional.softmax(coef, dim=-1)
+            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
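
For reference, a minimal standalone sketch of the residual-MoE combination that the new `forward` path performs. The `moe_layer`, `residual_mlp`, and `coefficient` objects below are hypothetical stand-ins (plain `Linear` modules), not the real routed MoE layer; the point is only the per-token softmax blend, whose two weights sum to 1:

```python
import torch

# Illustrative stand-ins for self.moe_layer, self.residual_mlp, and self.coefficient.
hidden_size, seq_len = 8, 4
moe_layer = torch.nn.Linear(hidden_size, hidden_size)
residual_mlp = torch.nn.Linear(hidden_size, hidden_size)
coefficient = torch.nn.Linear(hidden_size, 2)

hidden_states = torch.randn(seq_len, hidden_size)
output = moe_layer(hidden_states)         # routed expert output
output_mlp = residual_mlp(hidden_states)  # dense residual branch

# Per-token blend weights: softmax over two logits, so they sum to 1.
coef = torch.nn.functional.softmax(coefficient(hidden_states), dim=-1)
blended = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
print(blended.shape)  # torch.Size([4, 8])
```

The new keyword arguments are plumbed through `build_model_with_cfg`, so in an InternLM-style training config they would presumably be set alongside the existing model fields. The snippet below is a hypothetical example: the key names mirror the kwargs added in this diff, while the values and the surrounding fields are placeholders only.

```python
# Hypothetical model section of a training config; only the moe_* keys come from this diff.
model = dict(
    num_experts=8,                   # >1 switches the MLP to the MoE branch
    moe_gate_k=2,                    # top-2 gating
    moe_capacity_factor=1.25,
    moe_eval_capacity_factor=2.0,
    moe_min_capacity=4,
    moe_noisy_gate_policy="RSample",
    moe_drop_tokens=True,
    moe_use_rts=True,
    moe_use_residual=False,
    # ... existing model fields (hidden_size, num_layers, mlp_ratio, ...) unchanged
)
```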