diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index ee434f3..ac2733e 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -73,6 +73,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: int = 1,
+        moe_gate_k: int = 1,
+        moe_capacity_factor: float = 1.0,
+        moe_eval_capacity_factor: float = 1.0,
+        moe_min_capacity: int = 4,
+        moe_noisy_gate_policy: str = None,
+        moe_drop_tokens: bool = True,
+        moe_use_rts: bool = True,
+        moe_use_residual: bool = False,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -135,7 +143,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 dtype=dtype,
             )
         else:
-            expert = torch.nn.ModuleList(
+            experts = torch.nn.ModuleList(
                 [
                     FeedForward(
                         hidden_size,
@@ -149,9 +157,34 @@ class PackedFlashBaseLayer1D(nn.Module):
                     for i in range(num_experts // ep_size)
                 ]
             )
-            # TODO: test moe for now, need more parameter such as: capacity_factor,
-            # eval_capacity_factor, min_capacity, drop_tokens
-            self.mlp = MoE(hidden_size=hidden_size, expert=expert, ep_size=ep_size, num_experts=num_experts, k=1)
+
+            if moe_use_residual:
+                residual_mlp = FeedForward(
+                    hidden_size,
+                    int(hidden_size * gpc.config.model.mlp_ratio),
+                    out_features=hidden_size,
+                    process_group=gpc.get_group(ParallelMode.TENSOR),
+                    bias=False,
+                    device=torch.device("cuda"),
+                    dtype=torch.float,
+                )
+
+            self.mlp = MoE(
+                hidden_size=hidden_size,
+                experts=experts,
+                num_experts=num_experts,
+                ep_size=ep_size,
+                k=moe_gate_k,
+                capacity_factor=moe_capacity_factor,
+                eval_capacity_factor=moe_eval_capacity_factor,
+                min_capacity=moe_min_capacity,
+                noisy_gate_policy=moe_noisy_gate_policy,
+                drop_tokens=moe_drop_tokens,
+                use_rts=moe_use_rts,
+                use_residual=moe_use_residual,
+                residual_mlp=residual_mlp if moe_use_residual else None,
+            )
+
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
@@ -309,6 +342,14 @@ class PackedFlashInternLm1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: bool = 1,
+        moe_gate_k: int = 1,
+        moe_capacity_factor: float = 1.0,
+        moe_eval_capacity_factor: float = 1.0,
+        moe_min_capacity: int = 4,
+        moe_noisy_gate_policy: str = None,
+        moe_drop_tokens: bool = True,
+        moe_use_rts: bool = True,
+        moe_use_residual: bool = False,
     ):
         super().__init__()
 
@@ -361,6 +402,14 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     num_experts=num_experts,
+                    moe_gate_k=moe_gate_k,
+                    moe_capacity_factor=moe_capacity_factor,
+                    moe_eval_capacity_factor=moe_eval_capacity_factor,
+                    moe_min_capacity=moe_min_capacity,
+                    moe_noisy_gate_policy=moe_noisy_gate_policy,
+                    moe_drop_tokens=moe_drop_tokens,
+                    moe_use_rts=moe_use_rts,
+                    moe_use_residual=moe_use_residual,
                 )
                 for lid in range(num_layers)
             ]
@@ -499,6 +548,14 @@ def build_model_with_cfg(
     use_flash_attn: bool = True,
     sequence_parallel: bool = False,  # pylint: disable=W0613
     num_experts: int = 1,
+    moe_gate_k: int = 1,
+    moe_capacity_factor: float = 1.0,
+    moe_eval_capacity_factor: float = 1.0,
+    moe_min_capacity: int = 4,
+    moe_noisy_gate_policy: str = None,
+    moe_drop_tokens: bool = True,
+    moe_use_rts: bool = True,
+    moe_use_residual: bool = False,
 ):
     """
     Builde model with config
@@ -555,6 +612,14 @@ def build_model_with_cfg(
         use_flash_attn=use_flash_attn,
         sequence_parallel=sequence_parallel,
         num_experts=num_experts,
+        moe_gate_k=moe_gate_k,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_eval_capacity_factor=moe_eval_capacity_factor,
+        moe_min_capacity=moe_min_capacity,
+        moe_noisy_gate_policy=moe_noisy_gate_policy,
+        moe_drop_tokens=moe_drop_tokens,
+        moe_use_rts=moe_use_rts,
+        moe_use_residual=moe_use_residual,
     )
 
     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 180d829..4c2722c 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -58,7 +58,7 @@ class MoE(torch.nn.Module):
     def __init__(
         self,
         hidden_size,
-        expert,
+        experts,
         num_experts=1,
         ep_size=1,
         k=1,
@@ -69,6 +69,8 @@ class MoE(torch.nn.Module):
         drop_tokens: bool = True,
         use_rts: bool = True,
         using_default_moe: bool = True,
+        use_residual=True,
+        residual_mlp=None
     ):
         super().__init__()
 
@@ -89,7 +91,7 @@ class MoE(torch.nn.Module):
                 "Unsupported noisy_gate_policy: " + noisy_gate_policy
             )
 
-        experts = Experts(expert, self.num_local_experts)
+        experts = Experts(experts, self.num_local_experts)
 
         if using_default_moe:
             self.moe_layer = MOELayer(
@@ -110,6 +112,12 @@ class MoE(torch.nn.Module):
                 self.num_local_experts,
             )
 
+        self.use_residual = use_residual
+        if use_residual:
+            self.residual_mlp = residual_mlp
+            # coefficient is used for weighted sum of the output of expert and mlp
+            self.coefficient = torch.nn.Linear(hidden_size, 2)
+
     def forward(self, hidden_states, used_token=None):
         """MoE forward
 
@@ -127,5 +135,12 @@ class MoE(torch.nn.Module):
         * exp_counts (int): expert count
         """
        output = self.moe_layer(hidden_states, used_token)
-
+        if self.use_residual:
+            # Residual MoE
+            output_mlp = self.residual_mlp(hidden_states)
+            if type(output_mlp) is tuple:
+                output_mlp = output_mlp[0]  # Ignore the bias term for now
+            coef = self.coefficient(hidden_states)
+            coef = torch.nn.functional.softmax(coef, dim=-1)
+            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
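For reviewers, the residual-MoE blend added to `MoE.forward` can be exercised in isolation. The sketch below is a minimal approximation, not part of this diff: `DummyMoELayer`, `DummyMLP`, and `ResidualBlend` are hypothetical stand-ins for `MOELayer`, the residual `FeedForward`, and the surrounding `MoE` module. It only illustrates how `self.coefficient` turns the hidden states into a per-token softmax weighting of the expert output and the residual MLP output.

```python
# Minimal sketch of the residual-MoE weighted sum (stand-in modules, not the PR's classes).
import torch


class DummyMoELayer(torch.nn.Module):
    """Stand-in for MOELayer: a single linear 'expert'."""

    def __init__(self, hidden_size):
        super().__init__()
        self.expert = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.expert(x)


class DummyMLP(torch.nn.Module):
    """Stand-in for the residual FeedForward branch."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.fc(x)


class ResidualBlend(torch.nn.Module):
    """Reproduces only the weighted-sum logic added in MoE.forward."""

    def __init__(self, hidden_size):
        super().__init__()
        self.moe_layer = DummyMoELayer(hidden_size)
        self.residual_mlp = DummyMLP(hidden_size)
        # Two logits per token: one weight for the MoE output, one for the residual MLP.
        self.coefficient = torch.nn.Linear(hidden_size, 2)

    def forward(self, hidden_states):
        output = self.moe_layer(hidden_states)
        output_mlp = self.residual_mlp(hidden_states)
        coef = torch.nn.functional.softmax(self.coefficient(hidden_states), dim=-1)
        # Per-token convex combination of the two branches, as in the diff above.
        return output * coef[..., 0:1] + output_mlp * coef[..., 1:]


if __name__ == "__main__":
    x = torch.randn(4, 16, 32)  # (batch, seq_len, hidden_size)
    print(ResidualBlend(32)(x).shape)  # torch.Size([4, 16, 32])
```

Because the coefficients are softmax-normalized, the blend is a convex combination: with `moe_use_residual=True` the dense MLP acts as a shared fallback path, which the gate can favor for tokens that the routed experts handle poorly.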