mirror of https://github.com/InternLM/InternLM

feat(linear): optimize mlp by using jit (#321)

* fuse silu op
* refactor code
* fix lint
* fix lint

pull/322/head

parent 025ca55dfe
commit 6a5915bf0d
@@ -4,14 +4,13 @@
 from typing import Optional

 import torch
-import torch.nn.functional as F
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import fused_dense_func_torch
+from internlm.model.utils import Silu, fused_dense_func_torch


 class ScaleColumnParallelLinear(nn.Linear):
@@ -197,5 +196,7 @@ class FeedForward(nn.Module):
         )

     def forward(self, x):
-        out = self.w3(F.silu(self.w1(x)) * self.w2(x))
+        w1_o = self.w1(x)
+        w2_o = self.w2(x)
+        out = self.w3(Silu(w1_o, w2_o))
         return out
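Outside the diff, the rewritten forward is a SwiGLU-style gated MLP whose elementwise gate goes through a TorchScript-compiled helper, while the two up-projections and the down-projection stay as ordinary linear layers. A minimal, self-contained sketch of the same pattern (ToyFeedForward, _silu_gate, and the layer sizes are illustrative names and values, not part of the patch):

import torch
import torch.nn.functional as F
from torch import nn


@torch.jit.script
def _silu_gate(w1_o: torch.Tensor, w2_o: torch.Tensor) -> torch.Tensor:
    # Elementwise SiLU(w1_o) * w2_o; scripting gives the JIT a chance to fuse
    # the pointwise ops instead of running them as separate eager kernels.
    return F.silu(w1_o) * w2_o


class ToyFeedForward(nn.Module):
    # Same shape as the patched FeedForward: w1/w2 project up, w3 projects down.
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1_o = self.w1(x)
        w2_o = self.w2(x)
        return self.w3(_silu_gate(w1_o, w2_o))


if __name__ == "__main__":
    mlp = ToyFeedForward(dim=64, hidden=172)
    print(mlp(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])

Splitting out w1_o and w2_o mirrors the patch: the GEMMs are left to the (fused-dense) linear layers, and only the pointwise tail is routed through the scripted gate.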
@@ -130,6 +130,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             for _, param in self.mlp.named_parameters():
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)

         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
@@ -207,3 +207,10 @@ def try_import_RMSNorm():
         from internlm.model.norm import RMSNormTorch as RMSNorm

     return RMSNorm
+
+
+def Silu(w1_o, w2_o):
+    return F.silu(w1_o) * w2_o
+
+
+Silu = torch.jit.script(Silu)
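The new Silu helper in internlm.model.utils is what the FeedForward import above pulls in; wrapping it with torch.jit.script compiles the sigmoid-multiply inside SiLU and the gating multiply into one TorchScript graph, which the CUDA pointwise fuser can typically execute as a single fused kernel rather than several eager ones. A rough sanity check of the pattern under assumed shapes (silu_gate and scripted_gate are illustrative names):

import torch
import torch.nn.functional as F


def silu_gate(w1_o, w2_o):
    return F.silu(w1_o) * w2_o


# Script the helper the same way the patch does; results match eager execution.
scripted_gate = torch.jit.script(silu_gate)

a = torch.randn(4, 8)
b = torch.randn(4, 8)
assert torch.allclose(scripted_gate(a, b), silu_gate(a, b))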