mirror of https://github.com/InternLM/InternLM

fix single fwd

parent c437ffbfc9
commit 1d217eb94e
@@ -185,13 +185,16 @@ class RotaryEmbedding(torch.nn.Module):
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _cal_max_seqlen(self, indexes, max_seqlen=None):
+        if not isinstance(indexes, int):
+            if max_seqlen is None:  # We try to avoid calling the .item() function in fwd/bwd.
+                max_seqlen = indexes.max().item() + 1
+        else:
+            max_seqlen = indexes + 1  # eval_forward
+        return max_seqlen
+
+    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = self._cal_max_seqlen(indexes, max_seqlen)
         self._update_cos_sin_cache(qkv, max_seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
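
The refactor extracts the max_seqlen computation into a shared _cal_max_seqlen helper, so both _forward and _single_forward size the cos/sin cache the same way. A minimal, self-contained sketch of the helper's logic (the free-function form and the example inputs are illustrative, not from the commit):

import torch

def cal_max_seqlen(indexes, max_seqlen=None):
    if not isinstance(indexes, int):
        # Tensor of per-token positions: .item() forces a GPU sync, so a
        # caller-supplied max_seqlen is preferred in fwd/bwd when available.
        if max_seqlen is None:
            max_seqlen = indexes.max().item() + 1
    else:
        # Integer offset (eval path): the cache must cover position `indexes`.
        max_seqlen = indexes + 1
    return max_seqlen

print(cal_max_seqlen(torch.tensor([0, 1, 2, 0, 1]), max_seqlen=3))  # -> 3, no .item() sync
print(cal_max_seqlen(torch.tensor([0, 1, 2, 0, 1])))                # -> 3, via .item()
print(cal_max_seqlen(7))                                            # -> 8, eval offset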
@@ -223,9 +226,9 @@ class RotaryEmbedding(torch.nn.Module):
             self._sin_k_cached[seqlen_offset:],
         )
 
-    def _single_forward(self, x, indexes=0):
+    def _single_forward(self, x, indexes=0, **kwargs):
         assert self.scale is None
-        self._update_cos_sin_cache(x, indexes)
+        self._update_cos_sin_cache(x, self._cal_max_seqlen(indexes, kwargs.get("max_seqlen", None)))
         x = x[None, ...]
         ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
         return ret
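
This is the "single fwd" fix named in the commit message: _single_forward previously handed the raw indexes argument to _update_cos_sin_cache, which expects a sequence length, so a tensor of positions (or an int offset like the default 0) could size the cache incorrectly. Routing through _cal_max_seqlen, with max_seqlen accepted via **kwargs, derives a proper length first. A sketch of the failure mode with a stand-in cache updater (the function below is illustrative, not the repo's API):

import torch

def update_cos_sin_cache(seqlen):
    # The real method resizes self._cos_cached/_sin_cached to `seqlen` rows.
    assert isinstance(seqlen, int), "expects a length, not a tensor of positions"
    return torch.arange(seqlen)

indexes = torch.arange(5)  # per-token positions for one packed sequence

# Old call shape: the positions tensor itself was passed as the length.
try:
    update_cos_sin_cache(indexes)
except AssertionError as err:
    print("old call fails:", err)

# New call shape: derive the needed length first, as _cal_max_seqlen does.
max_seqlen = int(indexes.max().item()) + 1
print("cache sized to", update_cos_sin_cache(max_seqlen).numel(), "positions")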