From e9208728cbf096b6fdb0e6d7456053ade37061b4 Mon Sep 17 00:00:00 2001
From: "877825076@qq.com" <877825076@qq.com>
Date: Fri, 29 Dec 2023 16:47:19 +0800
Subject: [PATCH] fix

---
 internlm/model/modeling_moe.py | 13 +++++++++++--
 tests/test_core/utils.py       |  2 +-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index df6c7a8..cac2e43 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -448,7 +448,9 @@ class PackedFlashInternLm1D(nn.Module):
                     setattr(param, IS_TENSOR_PARALLEL, True)
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
+    ):
         # attention_mask: compute attention on the places where the value is 1
         # old condition may fail when use shared embedding
         if gpc.is_pipeline_first_stage():
@@ -470,7 +472,14 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         moe_losses = []
         for _, block in enumerate(self.blocks):
diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py
index 6f66a15..91561a1 100644
--- a/tests/test_core/utils.py
+++ b/tests/test_core/utils.py
@@ -33,7 +33,7 @@ class MlpModel(nn.Module):
         self.embedding = embedding
 
     def forward(
-        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):  # pylint: disable=W0613
         if self.model_type != "torch" and self.part[0] != 0:
             input_ids = hidden_states
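
Note (not part of the patch): the change above lets a caller supply max_seqlen through **kwargs instead of having forward() recompute it from cu_seqlens every time. Below is a minimal, self-contained sketch of that kwargs handling; the helper name resolve_max_seqlen and the example tensor are illustrative assumptions and do not exist in the InternLM codebase.

    import torch

    # Hypothetical helper that mirrors the kwargs handling added to forward()
    # in the patch above.
    def resolve_max_seqlen(cu_seqlens=None, **kwargs):
        if cu_seqlens is not None:
            if "max_seqlen" not in kwargs:
                # Derive the longest packed-sequence length from the cumulative offsets.
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            else:
                # Use the caller-provided value and consume it from kwargs.
                max_seqlen = kwargs.pop("max_seqlen")
        else:
            max_seqlen = None
        return max_seqlen

    # Example: two packed sequences of lengths 3 and 5.
    cu_seqlens = torch.tensor([0, 3, 8])
    print(resolve_max_seqlen(cu_seqlens))                # 5, computed from cu_seqlens
    print(resolve_max_seqlen(cu_seqlens, max_seqlen=8))  # 8, taken from kwargs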