mirror of https://github.com/InternLM/InternLM
fix single fwd
parent
c437ffbfc9
commit
1d217eb94e
|
@ -185,13 +185,16 @@ class RotaryEmbedding(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
return self._eval_forward(qkv)
|
return self._eval_forward(qkv)
|
||||||
|
|
||||||
def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _cal_max_seqlen(self, indexes, max_seqlen=None):
|
||||||
if not isinstance(indexes, int):
|
if not isinstance(indexes, int):
|
||||||
if max_seqlen is None: # We try to avoid call .item() function in fwd/bwd.
|
if max_seqlen is None: # We try to avoid call .item() function in fwd/bwd.
|
||||||
max_seqlen = indexes.max().item() + 1
|
max_seqlen = indexes.max().item() + 1
|
||||||
else:
|
else:
|
||||||
max_seqlen = indexes + 1 # eval_forward
|
max_seqlen = indexes + 1 # eval_forward
|
||||||
|
return max_seqlen
|
||||||
|
|
||||||
|
def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
max_seqlen = self._cal_max_seqlen(indexes, max_seqlen)
|
||||||
self._update_cos_sin_cache(qkv, max_seqlen)
|
self._update_cos_sin_cache(qkv, max_seqlen)
|
||||||
if self.scale is None:
|
if self.scale is None:
|
||||||
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
|
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
|
||||||
|
@ -223,9 +226,9 @@ class RotaryEmbedding(torch.nn.Module):
|
||||||
self._sin_k_cached[seqlen_offset:],
|
self._sin_k_cached[seqlen_offset:],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _single_forward(self, x, indexes=0):
|
def _single_forward(self, x, indexes=0, **kwargs):
|
||||||
assert self.scale is None
|
assert self.scale is None
|
||||||
self._update_cos_sin_cache(x, indexes)
|
self._update_cos_sin_cache(x, self._cal_max_seqlen(indexes, kwargs.get("max_seqlen", None)))
|
||||||
x = x[None, ...]
|
x = x[None, ...]
|
||||||
ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
|
ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
|
||||||
return ret
|
return ret
|
||||||
|
|
Loading…
Reference in New Issue