mirror of https://github.com/InternLM/InternLM
fix(model): set tensor parallel attribute for mlp (#271)
* set is_tensor_parallel attribute for mlp
* fix lint

pull/281/head

parent 0fb8d4141f
commit 9445faf5be
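For context: the commit moves the tensor-parallel tagging of MLP parameters out of FeedForward and up to the layer that constructs self.mlp. The underlying pattern is simply a marker attribute set on nn.Parameter objects so that parallel-aware utilities can filter on it later. Below is a minimal, self-contained sketch of that pattern, not InternLM code: ToyMLP, tag_tensor_parallel and tensor_parallel_params are illustrative names, IS_TENSOR_PARALLEL is a stand-in string for the constant imported from internlm.core.context, and the "consumer" side is an assumption about how such a flag is typically used.

import torch.nn.functional as F
from torch import nn

# Marker attribute name; InternLM imports an IS_TENSOR_PARALLEL constant from
# internlm.core.context -- here it is just a plain string for the sketch.
IS_TENSOR_PARALLEL = "is_tensor_parallel"


class ToyMLP(nn.Module):
    """Stand-in for the real MLP: three projections combined SwiGLU-style."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


def tag_tensor_parallel(module: nn.Module, tp_world_size: int) -> None:
    # Same shape as the loop this commit adds: tag every parameter of the
    # module whenever tensor parallelism is actually in use.
    for _, param in module.named_parameters():
        if tp_world_size > 1:
            setattr(param, IS_TENSOR_PARALLEL, True)


def tensor_parallel_params(module: nn.Module):
    # Hypothetical consumer: downstream utilities (e.g. gradient-norm or
    # checkpointing logic) can filter the parameters that carry the marker.
    return [p for p in module.parameters() if getattr(p, IS_TENSOR_PARALLEL, False)]


if __name__ == "__main__":
    mlp = ToyMLP(dim=8, hidden=16)
    tag_tensor_parallel(mlp, tp_world_size=2)
    print(len(tensor_parallel_params(mlp)))  # 3 -- the w1, w2 and w3 weights are tagged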
@@ -9,7 +9,7 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.utils import fused_dense_func_torch
 

@@ -195,12 +195,6 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
 
-        # need to assign tp attribute so that colossalai know it is tensor parallel module
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for name in ["w1", "w2", "w3"]:
-                for param in getattr(self, name).parameters():
-                    setattr(param, IS_TENSOR_PARALLEL, True)
-
     def forward(self, x):
         out = self.w3(F.silu(self.w1(x)) * self.w2(x))

@@ -127,6 +127,9 @@ class PackedFlashBaseLayer1D(nn.Module):
                 device=device,
                 dtype=dtype,
             )
+        for _, param in self.mlp.named_parameters():
+            if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+                setattr(param, IS_TENSOR_PARALLEL, True)
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
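Net effect of the three hunks: the IS_TENSOR_PARALLEL marking no longer lives inside FeedForward (and the first file drops the now-unused import); instead, PackedFlashBaseLayer1D tags every parameter of whatever module it assigned to self.mlp, guarded by the same tensor-parallel world-size check. Judging from the use_swiglu flag visible in the surrounding context lines, tagging at the call site presumably also covers MLP implementations other than FeedForward, which the FeedForward-local loop over w1/w2/w3 could not.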