diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index b1af14f..1029f0d 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -3,7 +3,7 @@
 
 import math
 from functools import wraps
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
@@ -381,13 +381,7 @@ class PackedFlashInternLm1D(nn.Module):
         self.parallel_output = parallel_output
 
     def forward(
-        self,
-        hidden_states=None,
-        cu_seqlens=None,
-        input_ids=None,
-        indexes=None,
-        inference_params=None,
-        max_seqlen: Optional[Union[int, None]] = None,
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
@@ -410,8 +404,13 @@ class PackedFlashInternLm1D(nn.Module):
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
 
-        if cu_seqlens is not None and max_seqlen is None:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         for _, block in enumerate(self.blocks):
            hidden_states = block(
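
The core of the last hunk is how max_seqlen is resolved once it moves from an explicit forward() parameter into **kwargs: a caller-supplied value wins, otherwise it is derived from cu_seqlens. A minimal standalone sketch of that logic, where resolve_max_seqlen is a hypothetical helper introduced only for illustration and is not part of the patch:

```python
import torch

def resolve_max_seqlen(cu_seqlens, **kwargs):
    # Mirrors the patched branch: prefer an explicitly passed max_seqlen,
    # otherwise derive it from the packed-batch offsets in cu_seqlens.
    if cu_seqlens is not None:
        if "max_seqlen" not in kwargs:
            # cu_seqlens holds cumulative sequence lengths, so adjacent
            # differences are the individual sequence lengths.
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        return kwargs.pop("max_seqlen")
    return None

# Example: three packed sequences of lengths 4, 7 and 2.
cu_seqlens = torch.tensor([0, 4, 11, 13])
print(resolve_max_seqlen(cu_seqlens))                 # 7, derived from cu_seqlens
print(resolve_max_seqlen(cu_seqlens, max_seqlen=16))  # 16, caller-supplied value wins
```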