pull/564/head
877825076@qq.com 2023-12-28 19:44:35 +08:00
parent 06ececeb00
commit 83989b57ae
1 changed file with 6 additions and 6 deletions

@@ -179,20 +179,20 @@ class RotaryEmbedding(torch.nn.Module):
     def forward(self, qkv: torch.Tensor, **kwargs):
         if kwargs.get("indexes", None) is not None:
-            return self._forward(qkv, kwargs.pop("indexes"))
+            return self._forward(qkv, kwargs.pop("indexes"), kwargs.get("max_seqlen", None))
         if kwargs.get("inference_params", None) is not None:
             return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0, seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
         if not isinstance(indexes, int):
-            if seqlen is None:  # We try to avoid trying item calls in fwd and bwd.
-                seqlen = indexes.max().item() + 1
+            if max_seqlen is None:  # We try to avoid calling .item() in fwd/bwd.
+                max_seqlen = indexes.max().item() + 1
         else:
-            seqlen = indexes + 1  # eval_forward
-        self._update_cos_sin_cache(qkv, seqlen)
+            max_seqlen = indexes + 1  # eval_forward
+        self._update_cos_sin_cache(qkv, max_seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
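In effect, the commit threads an optional max_seqlen through forward so that _forward can skip the indexes.max().item() call, which forces a GPU-to-CPU synchronization in both the forward and backward passes. A minimal standalone sketch of that dispatch logic, assuming a packed-sequence layout; resolve_max_seqlen is a hypothetical helper name for illustration, not part of the patch:

    import torch
    from typing import Optional

    def resolve_max_seqlen(indexes, max_seqlen: Optional[int] = None) -> int:
        # Simplified mirror of the branch added in _forward above (hypothetical helper).
        if not isinstance(indexes, int):
            if max_seqlen is None:  # fallback: pay the .item() device sync once
                max_seqlen = int(indexes.max().item()) + 1
        else:
            max_seqlen = indexes + 1  # eval path: indexes is a scalar offset
        return max_seqlen

    # Packed position ids for two sequences of length 4.
    indexes = torch.cat([torch.arange(4), torch.arange(4)])
    assert resolve_max_seqlen(indexes) == 4                # derived via .item()
    assert resolve_max_seqlen(indexes, max_seqlen=4) == 4  # sync-free path

When the caller (e.g. the training loop, which already knows the packed sequence lengths) supplies max_seqlen, the sync-free path is taken; otherwise the behavior falls back to the pre-commit .item() computation.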