mirror of https://github.com/InternLM/InternLM
commit 83989b57ae ("fix"), parent 06ececeb00
@@ -179,20 +179,20 @@ class RotaryEmbedding(torch.nn.Module):
     def forward(self, qkv: torch.Tensor, **kwargs):
         if kwargs.get("indexes", None) is not None:
-            return self._forward(qkv, kwargs.pop("indexes"))
+            return self._forward(qkv, kwargs.pop("indexes"), kwargs.get("max_seqlen", None))
         if kwargs.get("inference_params", None) is not None:
             return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
         else:
             return self._eval_forward(qkv)

-    def _forward(self, qkv: torch.Tensor, indexes=0, seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
         if not isinstance(indexes, int):
-            if seqlen is None:  # We try to avoid trying item calls in fwd and bwd.
-                seqlen = indexes.max().item() + 1
+            if max_seqlen is None:  # We try to avoid call .item() function in fwd/bwd.
+                max_seqlen = indexes.max().item() + 1
         else:
-            seqlen = indexes + 1  # eval_forward
+            max_seqlen = indexes + 1  # eval_forward

-        self._update_cos_sin_cache(qkv, seqlen)
+        self._update_cos_sin_cache(qkv, max_seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
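The point of the change is to let the caller hand the already-known max_seqlen down to _forward, so the rotary cos/sin cache can be sized without calling indexes.max().item(), which forces a device-to-host synchronization on every forward/backward pass. Below is a minimal standalone sketch of that idea, not the InternLM implementation; the helper names (build_cos_sin_cache, rotary_tables) and the packed-index example are illustrative assumptions.

    from typing import Optional

    import torch


    def build_cos_sin_cache(dim: int, seqlen: int, base: float = 10000.0):
        # Standard rotary cos/sin table for positions [0, seqlen).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(seqlen).float()
        freqs = torch.outer(t, inv_freq)
        return freqs.cos(), freqs.sin()


    def rotary_tables(packed_indexes: torch.Tensor, dim: int, max_seqlen: Optional[int] = None):
        # Without max_seqlen we would have to call .item(), which synchronizes
        # the device with the host on every forward/backward pass.
        if max_seqlen is None:
            max_seqlen = int(packed_indexes.max().item()) + 1  # host sync happens here
        cos, sin = build_cos_sin_cache(dim, max_seqlen)
        # Gather the per-token rows with the packed position indexes.
        return cos[packed_indexes], sin[packed_indexes]


    # Usage: the data loader already knows the longest sequence in the packed batch,
    # so it can pass max_seqlen down and skip the synchronizing .item() call.
    indexes = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])  # two packed sequences of length 3 and 5
    cos, sin = rotary_tables(indexes, dim=64, max_seqlen=5)
    print(cos.shape, sin.shape)  # torch.Size([8, 32]) torch.Size([8, 32])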