feat(linear): optimize mlp by using jit (#321)

* fuse silu op

* refactor code

* fix lint

* fix lint
pull/322/head
ytxiong 2023-09-19 14:57:43 +08:00 committed by GitHub
parent 025ca55dfe
commit 6a5915bf0d
3 changed files with 12 additions and 3 deletions
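
At its core, the change moves the SwiGLU activation out of FeedForward.forward and into a TorchScript-compiled helper, so that PyTorch's JIT fuser can combine the SiLU and the elementwise multiply instead of launching them as separate eager ops. A minimal standalone sketch of that helper (mirroring the Silu function the diff adds to internlm.model.utils):

import torch
import torch.nn.functional as F


# Scripted elementwise part of the SwiGLU MLP: silu(w1_o) * w2_o.
def Silu(w1_o, w2_o):
    return F.silu(w1_o) * w2_o


Silu = torch.jit.script(Silu)

if __name__ == "__main__":
    w1_o, w2_o = torch.randn(2, 8), torch.randn(2, 8)
    # Scripting changes how the ops are executed, not the math.
    assert torch.allclose(Silu(w1_o, w2_o), F.silu(w1_o) * w2_o)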

View File

@@ -4,14 +4,13 @@
 from typing import Optional

 import torch
-import torch.nn.functional as F
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import fused_dense_func_torch
+from internlm.model.utils import Silu, fused_dense_func_torch


 class ScaleColumnParallelLinear(nn.Linear):
@@ -197,5 +196,7 @@ class FeedForward(nn.Module):
         )

     def forward(self, x):
-        out = self.w3(F.silu(self.w1(x)) * self.w2(x))
+        w1_o = self.w1(x)
+        w2_o = self.w2(x)
+        out = self.w3(Silu(w1_o, w2_o))
         return out

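For reference, here is a single-device sketch of how FeedForward.forward looks after this change. Plain nn.Linear layers stand in for the column/row-parallel layers the real module builds on flash_attn's ColumnParallelLinear / RowParallelLinear, and the class name FeedForwardSketch is only for illustration:

import torch
import torch.nn.functional as F
from torch import nn


def Silu(w1_o, w2_o):
    return F.silu(w1_o) * w2_o


Silu = torch.jit.script(Silu)  # same scripted helper as in the sketch above


class FeedForwardSketch(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features, bias=False)
        self.w2 = nn.Linear(in_features, hidden_features, bias=False)
        self.w3 = nn.Linear(hidden_features, out_features, bias=False)

    def forward(self, x):
        # Same structure as the updated forward: two projections, fused SwiGLU, output projection.
        w1_o = self.w1(x)
        w2_o = self.w2(x)
        return self.w3(Silu(w1_o, w2_o))


x = torch.randn(2, 16)
print(FeedForwardSketch(16, 64, 16)(x).shape)  # torch.Size([2, 16])
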
View File

@@ -130,6 +130,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         for _, param in self.mlp.named_parameters():
             if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                 setattr(param, IS_TENSOR_PARALLEL, True)
+
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init

View File

@@ -207,3 +207,10 @@ def try_import_RMSNorm():
         from internlm.model.norm import RMSNormTorch as RMSNorm

         return RMSNorm
+
+
+def Silu(w1_o, w2_o):
+    return F.silu(w1_o) * w2_o
+
+
+Silu = torch.jit.script(Silu)
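
Whether the scripted helper is actually faster than the eager F.silu(w1_o) * w2_o depends on device, dtype, and tensor shapes, so any claim should be checked locally. A rough measurement harness (shapes are arbitrary; the warm-up loop gives the JIT's profiling executor a chance to specialize and fuse the graph before timing):

import torch
import torch.nn.functional as F
from torch.utils import benchmark


def silu_eager(w1_o, w2_o):
    return F.silu(w1_o) * w2_o


silu_scripted = torch.jit.script(silu_eager)

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(2048, 8192, device=device)
b = torch.randn(2048, 8192, device=device)

# Warm up the scripted function before timing.
for _ in range(5):
    silu_scripted(a, b)

for name, fn in (("eager", silu_eager), ("scripted", silu_scripted)):
    timer = benchmark.Timer(stmt="fn(a, b)", globals={"fn": fn, "a": a, "b": b})
    print(name, timer.timeit(50))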