From 1d217eb94e14047383dc1fd29ec477aa21646857 Mon Sep 17 00:00:00 2001
From: "877825076@qq.com" <877825076@qq.com>
Date: Fri, 29 Dec 2023 13:41:45 +0800
Subject: [PATCH] fix single fwd

---
 internlm/model/embedding.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index 9bdbd94..ed69a06 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -185,13 +185,16 @@ class RotaryEmbedding(torch.nn.Module):
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _cal_max_seqlen(self, indexes, max_seqlen=None):
         if not isinstance(indexes, int):
             if max_seqlen is None:  # We try to avoid call .item() function in fwd/bwd.
                 max_seqlen = indexes.max().item() + 1
         else:
             max_seqlen = indexes + 1  # eval_forward
+        return max_seqlen
 
+    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = self._cal_max_seqlen(indexes, max_seqlen)
         self._update_cos_sin_cache(qkv, max_seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
@@ -223,9 +226,9 @@ class RotaryEmbedding(torch.nn.Module):
             self._sin_k_cached[seqlen_offset:],
         )
 
-    def _single_forward(self, x, indexes=0):
+    def _single_forward(self, x, indexes=0, **kwargs):
         assert self.scale is None
-        self._update_cos_sin_cache(x, indexes)
+        self._update_cos_sin_cache(x, self._cal_max_seqlen(indexes, kwargs.get("max_seqlen", None)))
         x = x[None, ...]
         ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
         return ret
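
Note (not part of the patch): the issue being fixed is that _single_forward passed indexes
straight to _update_cos_sin_cache, which expects a sequence length; that only happens to
work when indexes is an int offset. The refactor factors the length derivation out into
_cal_max_seqlen and reuses it in _single_forward, optionally taking an explicit max_seqlen
through kwargs. The standalone sketch below mirrors that derivation for illustration; the
free-standing function name cal_max_seqlen and the sample lengths are assumptions for the
example, not code from the repository.

    # Standalone sketch of the max_seqlen derivation; mirrors the new
    # _cal_max_seqlen helper, written outside the RotaryEmbedding class for brevity.
    import torch

    def cal_max_seqlen(indexes, max_seqlen=None):
        if not isinstance(indexes, int):
            if max_seqlen is None:  # avoid .item() when the caller already knows the length
                max_seqlen = indexes.max().item() + 1
        else:
            max_seqlen = indexes + 1  # eval path: indexes is a scalar offset
        return max_seqlen

    print(cal_max_seqlen(torch.arange(512)))        # 512: cache must cover position 511
    print(cal_max_seqlen(torch.arange(512), 1024))  # 1024: an explicit max_seqlen wins
    print(cal_max_seqlen(7))                        # 8: int offset behaves as before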