diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 5a3a4eb..d18308a 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -4,14 +4,13 @@
 from typing import Optional
 
 import torch
-import torch.nn.functional as F
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import fused_dense_func_torch
+from internlm.model.utils import Silu, fused_dense_func_torch
 
 
 class ScaleColumnParallelLinear(nn.Linear):
@@ -197,5 +196,7 @@ class FeedForward(nn.Module):
         )
 
     def forward(self, x):
-        out = self.w3(F.silu(self.w1(x)) * self.w2(x))
+        w1_o = self.w1(x)
+        w2_o = self.w2(x)
+        out = self.w3(Silu(w1_o, w2_o))
         return out
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 64ff4de..651a629 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -130,6 +130,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             for _, param in self.mlp.named_parameters():
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
 
+        self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 12f80e3..76ba1a5 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -207,3 +207,10 @@ def try_import_RMSNorm():
         from internlm.model.norm import RMSNormTorch as RMSNorm
 
         return RMSNorm
+
+
+def Silu(w1_o, w2_o):
+    return F.silu(w1_o) * w2_o
+
+
+Silu = torch.jit.script(Silu)
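
The patch swaps the inline `F.silu(self.w1(x)) * self.w2(x)` expression for a TorchScript-compiled `Silu` helper so the SwiGLU gate can be fused. Below is a minimal, self-contained sketch of the same idea; `ToyFeedForward` and its dimensions are illustrative stand-ins (the real `FeedForward` uses `ColumnParallelLinear`/`RowParallelLinear` from flash-attn), and only the scripted `Silu` mirrors the code added to `internlm/model/utils.py`.

```python
import torch
import torch.nn.functional as F
from torch import nn


def Silu(w1_o, w2_o):
    # SwiGLU gate: silu(w1(x)) * w2(x); scripting lets TorchScript fuse the
    # activation and the element-wise product into one kernel where possible.
    return F.silu(w1_o) * w2_o


Silu = torch.jit.script(Silu)


class ToyFeedForward(nn.Module):
    """Plain nn.Linear stand-in for the tensor-parallel FeedForward in the patch."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        w1_o = self.w1(x)
        w2_o = self.w2(x)
        return self.w3(Silu(w1_o, w2_o))


if __name__ == "__main__":
    torch.manual_seed(0)
    mlp = ToyFeedForward(dim=16, hidden_dim=64)
    x = torch.randn(2, 8, 16)
    # The scripted helper should be numerically identical to the old eager expression.
    eager = mlp.w3(F.silu(mlp.w1(x)) * mlp.w2(x))
    print("scripted SwiGLU matches eager:", torch.allclose(mlp(x), eager))
```

Note that the scripted helper changes only how the gate is computed, not what it computes, so the check above should report a match.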