From 9445faf5bef62304a54b6edd5dd7c9d737e466b6 Mon Sep 17 00:00:00 2001
From: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Date: Tue, 5 Sep 2023 19:03:02 +0800
Subject: [PATCH] fix(model): set tensor parallel attribute for mlp (#271)

* set is_tensor_parallel attribute for mlp

* fix lint
---
 internlm/model/linear.py            | 8 +-------
 internlm/model/modeling_internlm.py | 3 +++
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 32f29f8..5a3a4eb 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -9,7 +9,7 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.utils import fused_dense_func_torch
 
@@ -195,12 +195,6 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        # need to assign tp attribute so that colossalai know it is tensor parallel module
-
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for name in ["w1", "w2", "w3"]:
-                for param in getattr(self, name).parameters():
-                    setattr(param, IS_TENSOR_PARALLEL, True)
 
     def forward(self, x):
         out = self.w3(F.silu(self.w1(x)) * self.w2(x))
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 4494959..0ca805e 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -127,6 +127,9 @@ class PackedFlashBaseLayer1D(nn.Module):
             device=device,
             dtype=dtype,
         )
+        for _, param in self.mlp.named_parameters():
+            if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+                setattr(param, IS_TENSOR_PARALLEL, True)
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
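
Below is a minimal, self-contained sketch (not the InternLM code itself) of the tagging pattern this patch moves from FeedForward in linear.py into PackedFlashBaseLayer1D in modeling_internlm.py: when the tensor parallel world size is greater than 1, every parameter of the MLP is marked with an IS_TENSOR_PARALLEL attribute so downstream parallel logic can tell sharded parameters apart from replicated ones. The toy MLP, the hard-coded world size, and the standalone IS_TENSOR_PARALLEL string constant below are stand-ins for internlm.core.context and gpc.get_world_size(ParallelMode.TENSOR), which are assumed here rather than imported.

    from torch import nn

    # Stand-in for internlm.core.context.IS_TENSOR_PARALLEL (assumed to be a string
    # attribute name, since the patch passes it directly to setattr).
    IS_TENSOR_PARALLEL = "is_tensor_parallel"


    def tag_tensor_parallel(module: nn.Module, tp_world_size: int) -> None:
        """Mark every parameter of `module` as tensor parallel when TP is enabled."""
        for _, param in module.named_parameters():
            if tp_world_size > 1:
                setattr(param, IS_TENSOR_PARALLEL, True)


    if __name__ == "__main__":
        # Toy MLP standing in for the FeedForward / ParallelFusedMLP block.
        mlp = nn.Sequential(nn.Linear(8, 32), nn.SiLU(), nn.Linear(32, 8))
        # Pretend gpc.get_world_size(ParallelMode.TENSOR) returned 2.
        tag_tensor_parallel(mlp, tp_world_size=2)
        print(all(getattr(p, IS_TENSOR_PARALLEL, False) for p in mlp.parameters()))  # True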