mirror of https://github.com/InternLM/InternLM
refactor(rotaryEmbedding): refactor forward (#120)
* use fp16 in instruction (#80)
* delete torch_dtype of README's example code (#100)
* refactor the forward for rotary embedding

---------

Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>

pull/139/head
parent 762ab297ee
commit fd398fae1a
@@ -176,7 +176,15 @@ class RotaryEmbedding(torch.nn.Module):
                 self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                 self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, qkv: torch.Tensor, **kwargs):
+        if kwargs.get("indexes", None) is not None:
+            return self._forward(qkv, kwargs.pop("indexes"))
+        if kwargs.get("inference_params", None) is not None:
+            return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
+        else:
+            return self._eval_forward(qkv)
+
+    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
         self._update_cos_sin_cache(qkv, indexes)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
@@ -189,7 +197,7 @@ class RotaryEmbedding(torch.nn.Module):
                 self._sin_k_cached[indexes],
             )
 
-    def eval_forward(self, qkv, seqlen_offset=0):
+    def _eval_forward(self, qkv, seqlen_offset=0):
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
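The two hunks above turn the old training-only forward into a dispatcher and rename eval_forward to _eval_forward: an indexes keyword routes to _forward, inference_params routes to _eval_forward with the stored sequence offset, and a bare call falls back to _eval_forward. A rough sketch of the three call patterns follows; the wrapper function, shapes, dtypes, and the SimpleNamespace stand-in for an inference_params object are illustrative assumptions, not part of this commit, and a real constructed RotaryEmbedding instance must be supplied.

import torch
from types import SimpleNamespace

def demo_rotary_dispatch(rotary_emb):
    """Exercise the three branches of the new RotaryEmbedding.forward.
    `rotary_emb` must be a constructed RotaryEmbedding; shapes below are assumed."""
    # Packed / training path: per-token position indexes select rows of the cos/sin cache.
    packed_qkv = torch.randn(256, 3, 8, 64, dtype=torch.float16)  # assumed (total_tokens, 3, heads, head_dim)
    rotary_emb(packed_qkv, indexes=torch.arange(256))             # -> _forward(qkv, indexes)

    # Generation path: the offset is read from inference_params.sequence_len_offset.
    step_qkv = torch.randn(1, 1, 3, 8, 64, dtype=torch.float16)   # assumed (batch, seqlen, 3, heads, head_dim)
    params = SimpleNamespace(sequence_len_offset=16)              # stand-in for an InferenceParams object
    rotary_emb(step_qkv, inference_params=params)                 # -> _eval_forward(qkv, seqlen_offset=16)

    # No recognised keyword: falls through to _eval_forward(qkv) with seqlen_offset=0.
    rotary_emb(step_qkv)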
@@ -107,9 +107,9 @@ class MHA(nn.Module):
         if kwargs.get("indexes", None) is not None:
             return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
         else:
-            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params)
+            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
 
-    def _forward(self, x, seqlen=None, inference_params=None):
+    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
@@ -124,10 +124,8 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 
         if self.rotary_emb_dim > 0:
-            if inference_params is None:
-                qkv = self.rotary_emb.eval_forward(qkv)
-            else:
-                qkv = self.rotary_emb.eval_forward(qkv, seqlen_offset=inference_params.sequence_len_offset)
+            kwargs["inference_params"] = inference_params
+            qkv = self.rotary_emb(qkv, **kwargs)
 
         if inference_params is None:
             context = self.inner_attn(qkv)
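With the hunk above, MHA._forward no longer picks between eval_forward variants itself; it stashes inference_params into kwargs and lets RotaryEmbedding.forward decide the branch. A minimal, self-contained sketch of that equivalence; the toy stand-in class and its return values are assumptions made purely to keep the example executable, not the repository's API.

from types import SimpleNamespace

class _RotarySketch:
    """Toy stand-in mirroring the dispatch this commit adds to RotaryEmbedding."""
    def _eval_forward(self, qkv, seqlen_offset=0):
        return ("eval", seqlen_offset)          # real code would rotate qkv here
    def __call__(self, qkv, **kwargs):          # new-style RotaryEmbedding.forward
        if kwargs.get("inference_params", None) is not None:
            return self._eval_forward(qkv, seqlen_offset=kwargs["inference_params"].sequence_len_offset)
        return self._eval_forward(qkv)

rotary = _RotarySketch()
params = SimpleNamespace(sequence_len_offset=16)   # stand-in for inference_params

# Old call site (removed lines): MHA._forward picked the seqlen_offset itself.
old_prefill = rotary._eval_forward("qkv")
old_decode = rotary._eval_forward("qkv", seqlen_offset=params.sequence_len_offset)

# New call site (added lines): inference_params rides along in kwargs.
new_prefill = rotary("qkv", **{"inference_params": None})
new_decode = rotary("qkv", **{"inference_params": params})

assert old_prefill == new_prefill and old_decode == new_decode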
@@ -158,7 +156,8 @@ class MHA(nn.Module):
         """
         qkv = self.Wqkv(x)  # total x hsz'
         qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
-        qkv = self.rotary_emb(qkv, kwargs.pop("indexes"))
+        qkv = self.rotary_emb(qkv, **kwargs)
+        kwargs.pop("indexes")
 
         if inference_params is None:
             context = self.inner_attn(qkv, **kwargs)
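One subtlety in the last hunk: rotary_emb(qkv, **kwargs) unpacks kwargs into a fresh dict inside RotaryEmbedding.forward, so the kwargs.pop("indexes") performed there cannot remove the key from MHA's own kwargs; the explicit kwargs.pop("indexes") on the following added line is what keeps "indexes" out of the later self.inner_attn(qkv, **kwargs) call. A tiny standalone demonstration of that Python behaviour (the dictionary keys are illustrative):

def callee(**kw):
    kw.pop("indexes")        # mutates only the callee's private dict built by ** unpacking

caller_kwargs = {"indexes": [0, 1, 2], "max_seqlen": 128}   # illustrative keys
callee(**caller_kwargs)
print("indexes" in caller_kwargs)   # True: the caller still has to pop it itself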