diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index 751bce5..9bdbd94 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -179,20 +179,20 @@ class RotaryEmbedding(torch.nn.Module):
 
     def forward(self, qkv: torch.Tensor, **kwargs):
         if kwargs.get("indexes", None) is not None:
-            return self._forward(qkv, kwargs.pop("indexes"))
+            return self._forward(qkv, kwargs.pop("indexes"), kwargs.get("max_seqlen", None))
         if kwargs.get("inference_params", None) is not None:
             return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0, seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
         if not isinstance(indexes, int):
-            if seqlen is None:  # We try to avoid trying item calls in fwd and bwd.
-                seqlen = indexes.max().item() + 1
+            if max_seqlen is None:  # We try to avoid calling the .item() function in fwd/bwd.
+                max_seqlen = indexes.max().item() + 1
         else:
-            seqlen = indexes + 1  # eval_forward
+            max_seqlen = indexes + 1  # eval_forward
 
-        self._update_cos_sin_cache(qkv, seqlen)
+        self._update_cos_sin_cache(qkv, max_seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
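
The point of the extra parameter: indexes.max().item() forces a GPU-to-CPU synchronization on every forward/backward pass, so the change lets callers hand in a max_seqlen they already know on the host and keeps the .item() call only as a fallback. The sketch below is not the InternLM class itself; it is a minimal, hypothetical stand-in (ToyRotaryCache, its toy cache math, and the example shapes are invented for illustration) that mirrors only the max_seqlen plumbing shown in the diff, assuming the caller learns the packed sequence length from its data pipeline.

from typing import Optional

import torch


class ToyRotaryCache(torch.nn.Module):
    """Hypothetical stand-in for the cos-cache part of RotaryEmbedding."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.dim = dim
        self._seq_len_cached = 0
        self._cos_cached = None

    def _update_cos_sin_cache(self, seqlen: int):
        # Rebuild the cache only when a longer sequence is requested.
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            inv_freq = 1.0 / (10000 ** (torch.arange(self.dim, dtype=torch.float32) / self.dim))
            t = torch.arange(seqlen, dtype=torch.float32)
            self._cos_cached = torch.cos(torch.outer(t, inv_freq))

    def forward(self, indexes: torch.Tensor, max_seqlen: Optional[int] = None):
        if max_seqlen is None:
            # Old behaviour: .item() synchronizes device and host on every call.
            max_seqlen = indexes.max().item() + 1
        self._update_cos_sin_cache(max_seqlen)
        return self._cos_cached[indexes]


indexes = torch.arange(16)            # flattened position ids of a packed batch
rotary = ToyRotaryCache()
cos = rotary(indexes, max_seqlen=16)  # length known on the host, no .item() sync
print(cos.shape)                      # torch.Size([16, 8])

The real module follows the same idea: a training loop that already knows the maximum sequence length of its packed batch can pass it through kwargs["max_seqlen"], and the .item() fallback only runs when that hint is absent.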