fix single fwd

pull/564/head
877825076@qq.com 2023-12-29 13:41:45 +08:00
parent c437ffbfc9
commit 1d217eb94e
1 changed file with 6 additions and 3 deletions

@@ -185,13 +185,16 @@ class RotaryEmbedding(torch.nn.Module):
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _cal_max_seqlen(self, indexes, max_seqlen=None):
         if not isinstance(indexes, int):
             if max_seqlen is None:  # We try to avoid call .item() function in fwd/bwd.
                 max_seqlen = indexes.max().item() + 1
         else:
             max_seqlen = indexes + 1  # eval_forward
+        return max_seqlen
+    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = self._cal_max_seqlen(indexes, max_seqlen)
         self._update_cos_sin_cache(qkv, max_seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
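For reference, a minimal standalone sketch of what the extracted helper computes (the body mirrors `_cal_max_seqlen` in the hunk above; assumes `torch` is available):

    import torch

    def cal_max_seqlen(indexes, max_seqlen=None):
        # When `indexes` is a tensor of per-token positions, the largest
        # position plus one bounds the required cos/sin cache length.
        # The .item() call (a host-device sync) is skipped whenever the
        # caller already supplies max_seqlen.
        if not isinstance(indexes, int):
            if max_seqlen is None:
                max_seqlen = indexes.max().item() + 1
        else:
            max_seqlen = indexes + 1  # eval path: `indexes` is a seqlen offset
        return max_seqlen

    assert cal_max_seqlen(torch.tensor([0, 1, 2, 5])) == 6
    assert cal_max_seqlen(7) == 8
    assert cal_max_seqlen(torch.tensor([0, 1]), max_seqlen=16) == 16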
@@ -223,9 +226,9 @@ class RotaryEmbedding(torch.nn.Module):
                 self._sin_k_cached[seqlen_offset:],
             )
 
-    def _single_forward(self, x, indexes=0):
+    def _single_forward(self, x, indexes=0, **kwargs):
         assert self.scale is None
-        self._update_cos_sin_cache(x, indexes)
+        self._update_cos_sin_cache(x, self._cal_max_seqlen(indexes, kwargs.get("max_seqlen", None)))
         x = x[None, ...]
         ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
         return ret
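A sketch of why the second hunk matters (the cache-update stand-in below is hypothetical; only the argument handling follows the diff): `_update_cos_sin_cache` expects a maximum sequence length, but the old code handed it the raw `indexes`, which in training is a tensor of positions rather than a length.

    import torch

    def update_cos_sin_cache(seqlen, cached_len=0):
        # Hypothetical stand-in: the real method grows the cos/sin tables
        # when `seqlen` exceeds the cached length, so it needs an int length.
        return max(int(seqlen), cached_len)

    indexes = torch.tensor([0, 1, 2, 5])  # per-token positions
    # Old call: update_cos_sin_cache(indexes) -- a multi-element tensor
    # where an int length is expected, so the cache check misbehaves.
    # New call: derive the length first, as _cal_max_seqlen now does.
    new_len = update_cos_sin_cache(indexes.max().item() + 1)
    print(new_len)  # 6: enough cache entries to index positions 0..5

Accepting `**kwargs` keeps the old two-argument call sites working while letting callers forward `max_seqlen` and skip the `.item()` sync, matching the behavior `_forward` already has.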