mirror of https://github.com/InternLM/InternLM
refactor(rotaryEmbedding): refactor forward (#120)
* use fp16 in instruction (#80)
* delete torch_dtype of README's example code (#100)
* refactor the forward for rotary embedding

Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
parent 762ab297ee
commit fd398fae1a
@@ -176,7 +176,15 @@ class RotaryEmbedding(torch.nn.Module):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

-    def forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, qkv: torch.Tensor, **kwargs):
+        if kwargs.get("indexes", None) is not None:
+            return self._forward(qkv, kwargs.pop("indexes"))
+        if kwargs.get("inference_params", None) is not None:
+            return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
+        else:
+            return self._eval_forward(qkv)
+
+    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
         self._update_cos_sin_cache(qkv, indexes)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
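With this hunk, RotaryEmbedding.forward becomes a thin dispatcher: a packed-sequence training call supplies indexes, a generation call supplies inference_params, and anything else falls through to the offset-0 eval path. A minimal usage sketch of the three call shapes, assuming qkv is already in the layout _forward/_eval_forward expect and that inference_params exposes a sequence_len_offset attribute as the hunk suggests:

    # Hypothetical usage; rotary_emb is an existing RotaryEmbedding module and
    # qkv / indexes / inference_params come from the surrounding model code.
    qkv = rotary_emb(qkv, indexes=indexes)                    # packed training -> _forward(qkv, indexes)
    qkv = rotary_emb(qkv, inference_params=inference_params)  # generation -> _eval_forward(qkv, seqlen_offset=inference_params.sequence_len_offset)
    qkv = rotary_emb(qkv)                                      # no keywords -> _eval_forward(qkv), offset 0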
@@ -189,7 +197,7 @@ class RotaryEmbedding(torch.nn.Module):
                 self._sin_k_cached[indexes],
             )

-    def eval_forward(self, qkv, seqlen_offset=0):
+    def _eval_forward(self, qkv, seqlen_offset=0):
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
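Since forward now owns the routing, the public eval_forward is renamed to the private _eval_forward; its seqlen_offset argument keeps the meaning described in the docstring. A small sketch of that meaning during incremental decoding, with hypothetical values:

    # Assume a 10-token prompt has already been processed and only the newest token's
    # qkv is passed in. The dispatcher turns inference_params.sequence_len_offset (here 10)
    # into seqlen_offset, so the new token is rotated as position 10, not position 0.
    qkv_step = rotary_emb(qkv_new, inference_params=inference_params)  # == _eval_forward(qkv_new, seqlen_offset=10)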
@@ -107,9 +107,9 @@ class MHA(nn.Module):
         if kwargs.get("indexes", None) is not None:
             return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
         else:
-            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params)
+            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)

-    def _forward(self, x, seqlen=None, inference_params=None):
+    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
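The MHA entry point keeps its existing indexes-based routing; the change is that **kwargs is now forwarded into _forward as well (and _forward accepts it), so extra keywords can reach the rotary embedding on the padded path too. A sketch of the two routes, assuming the dispatch shown in the hunk and with varlen_kwargs as a hypothetical stand-in for whatever else the caller passes:

    # Hypothetical calls; mha is an MHA module, the other names come from the caller.
    out = mha(x, indexes=indexes, **varlen_kwargs)       # -> _packed_forward(x=x, inference_params=..., **kwargs)
    out = mha(x, inference_params=inference_params)      # -> _forward(x=x, seqlen=None, inference_params=..., **kwargs)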
@@ -124,10 +124,8 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)

         if self.rotary_emb_dim > 0:
-            if inference_params is None:
-                qkv = self.rotary_emb.eval_forward(qkv)
-            else:
-                qkv = self.rotary_emb.eval_forward(qkv, seqlen_offset=inference_params.sequence_len_offset)
+            kwargs["inference_params"] = inference_params
+            qkv = self.rotary_emb(qkv, **kwargs)

         if inference_params is None:
             context = self.inner_attn(qkv)
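Here the explicit branch on inference_params around rotary_emb.eval_forward is removed; _forward simply records inference_params in kwargs and lets the new RotaryEmbedding.forward choose the path. A minimal sketch of why the behaviour is unchanged (hypothetical helper, names as in the hunks above):

    def apply_rotary(rotary_emb, qkv, inference_params, **kwargs):
        # forward() checks "indexes" first, then "inference_params". With inference_params=None
        # the lookup yields None, so it falls through to _eval_forward(qkv) -- the old
        # eval_forward(qkv) branch. Otherwise it calls _eval_forward with
        # seqlen_offset=inference_params.sequence_len_offset, matching the removed else-branch.
        kwargs["inference_params"] = inference_params
        return rotary_emb(qkv, **kwargs)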
@@ -158,7 +156,8 @@ class MHA(nn.Module):
         """
         qkv = self.Wqkv(x)  # total x hsz'
         qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
-        qkv = self.rotary_emb(qkv, kwargs.pop("indexes"))
+        qkv = self.rotary_emb(qkv, **kwargs)
+        kwargs.pop("indexes")

         if inference_params is None:
             context = self.inner_attn(qkv, **kwargs)
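In _packed_forward, the rotary embedding now receives the whole kwargs dict (so the dispatcher can read indexes from it) instead of a positional indexes argument. Because **kwargs unpacking hands forward() its own copy of the dict, the pop inside forward() does not touch the caller's kwargs, hence the explicit kwargs.pop("indexes") afterwards so inner_attn is not handed an unexpected keyword. A short sketch of the handoff, with the varlen-attention keys shown only as assumptions:

    # Hypothetical contents; the real kwargs are whatever the caller of _packed_forward supplies.
    kwargs = {"indexes": indexes, "cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen}
    qkv = self.rotary_emb(qkv, **kwargs)   # forward() pops "indexes" from its own copy only
    kwargs.pop("indexes")                  # drop it here too, before it reaches inner_attn
    context = self.inner_attn(qkv, **kwargs)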