refactor(rotaryEmbedding): refactor forward (#120)

* use fp16 in instruction (#80)

* remove torch_dtype from the README's example code (#100)

* refactor the forward for rotary embedding

---------

Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
ytxiong 2023-07-25 15:25:48 +08:00 committed by GitHub
parent 762ab297ee
commit fd398fae1a
2 changed files with 17 additions and 10 deletions


@@ -175,8 +175,16 @@ class RotaryEmbedding(torch.nn.Module):
                 self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                 self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                 self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, qkv: torch.Tensor, **kwargs):
+        if kwargs.get("indexes", None) is not None:
+            return self._forward(qkv, kwargs.pop("indexes"))
+        if kwargs.get("inference_params", None) is not None:
+            return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
+        else:
+            return self._eval_forward(qkv)
+
+    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
         self._update_cos_sin_cache(qkv, indexes)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
@@ -189,7 +197,7 @@ class RotaryEmbedding(torch.nn.Module):
                 self._sin_k_cached[indexes],
             )
 
-    def eval_forward(self, qkv, seqlen_offset=0):
+    def _eval_forward(self, qkv, seqlen_offset=0):
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
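
For readers skimming the diff: a minimal, hedged sketch of how the refactored forward is expected to dispatch. DummyRotary and InferenceParams below are stand-ins invented for illustration (only the sequence_len_offset attribute mirrors the diff); they are not the real classes from this repository.

from dataclasses import dataclass

import torch


@dataclass
class InferenceParams:
    # Hypothetical stand-in; only sequence_len_offset mirrors the attribute used in the diff.
    sequence_len_offset: int = 0


class DummyRotary:
    # Illustrative copy of the dispatch logic added above; _forward/_eval_forward are stubbed.
    def forward(self, qkv: torch.Tensor, **kwargs):
        if kwargs.get("indexes", None) is not None:
            # Packed (training) path: per-token position indexes are provided.
            return self._forward(qkv, kwargs.pop("indexes"))
        if kwargs.get("inference_params", None) is not None:
            # Generation path: the offset comes from the inference params.
            return self._eval_forward(qkv, seqlen_offset=kwargs["inference_params"].sequence_len_offset)
        return self._eval_forward(qkv)

    def _forward(self, qkv, indexes):
        return f"packed path, indexes shape {tuple(indexes.shape)}"

    def _eval_forward(self, qkv, seqlen_offset=0):
        return f"eval path, seqlen_offset={seqlen_offset}"


qkv = torch.zeros(1, 4, 3, 2, 8)  # (batch, seqlen, 3, nheads, headdim)
rotary = DummyRotary()
print(rotary.forward(qkv, indexes=torch.arange(4)))                                  # packed training path
print(rotary.forward(qkv, inference_params=InferenceParams(sequence_len_offset=5)))  # generation path
print(rotary.forward(qkv))                                                           # plain eval path

Called with indexes it takes the packed-training path, called with inference_params it takes the generation path using the cached-length offset, and with neither it falls back to plain evaluation.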


@@ -107,9 +107,9 @@ class MHA(nn.Module):
         if kwargs.get("indexes", None) is not None:
             return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
         else:
-            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params)
+            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
 
-    def _forward(self, x, seqlen=None, inference_params=None):
+    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
@@ -124,10 +124,8 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 
         if self.rotary_emb_dim > 0:
-            if inference_params is None:
-                qkv = self.rotary_emb.eval_forward(qkv)
-            else:
-                qkv = self.rotary_emb.eval_forward(qkv, seqlen_offset=inference_params.sequence_len_offset)
+            kwargs["inference_params"] = inference_params
+            qkv = self.rotary_emb(qkv, **kwargs)
 
         if inference_params is None:
             context = self.inner_attn(qkv)
@@ -158,7 +156,8 @@ class MHA(nn.Module):
         """
         qkv = self.Wqkv(x) # total x hsz'
         qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d
-        qkv = self.rotary_emb(qkv, kwargs.pop("indexes"))
+        qkv = self.rotary_emb(qkv, **kwargs)
+        kwargs.pop("indexes")
 
         if inference_params is None:
             context = self.inner_attn(qkv, **kwargs)
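
A note on the _packed_forward change: the full kwargs (including "indexes") are handed to the rotary embedding first, and "indexes" is popped only afterwards so it is not forwarded to inner_attn. A hedged sketch of that kwargs-threading pattern, with placeholder callables (apply_rotary, inner_attn and packed_forward here are invented for illustration, not the module's real methods):

import torch


def apply_rotary(qkv, **kwargs):
    # Stand-in for self.rotary_emb(qkv, **kwargs); it is the only consumer of "indexes".
    print("rotary sees:", sorted(kwargs))
    return qkv


def inner_attn(qkv, **kwargs):
    # Stand-in for self.inner_attn(qkv, **kwargs); it must not receive "indexes".
    print("attention sees:", sorted(kwargs))
    return qkv


def packed_forward(x, **kwargs):
    qkv = x  # projection and rearrange omitted for brevity
    qkv = apply_rotary(qkv, **kwargs)  # full kwargs forwarded, rotary reads "indexes"
    kwargs.pop("indexes")              # drop it before the attention call
    return inner_attn(qkv, **kwargs)


packed_forward(torch.zeros(8, 3, 2, 4), indexes=torch.arange(8), max_seqlen=8)
# rotary sees: ['indexes', 'max_seqlen']
# attention sees: ['max_seqlen']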