[fix bug] Fix the bug where RotaryEmbedding is cast to a non-fp32 dtype during training, and add a compatibility method for the llama model. (#239)
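
Root cause: inv_freq and scale were registered as module buffers, and buffers follow module-wide dtype conversions such as model.to(torch.bfloat16), so the rotary tables silently lost their fp32 precision during mixed-precision training; plain tensor attributes are left untouched by such conversions. A minimal toy sketch of that behaviour (illustrative only; Toy, inv_freq_buf, and inv_freq_attr are made-up names, not code from this repo):

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
            self.register_buffer("inv_freq_buf", inv_freq)  # buffers are cast by module-wide conversions
            self.inv_freq_attr = inv_freq.clone()           # plain attributes keep their dtype

    m = Toy().to(torch.bfloat16)
    print(m.inv_freq_buf.dtype)   # torch.bfloat16 -> rotary table loses precision
    print(m.inv_freq_attr.dtype)  # torch.float32  -> stays fp32

Storing the tables as plain attributes, as this commit does, therefore keeps them in fp32 no matter how the surrounding model is cast.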

Co-authored-by: YWMditto <862779238@qq.com>
pull/240/head
YWMditto 2023-08-26 17:48:08 +08:00 committed by GitHub
parent 54f85a6e9a
commit 28635755f5
1 changed file with 14 additions and 4 deletions

@@ -137,15 +137,13 @@ class RotaryEmbedding(torch.nn.Module):
         """ """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
         self.scale_base = scale_base
-        scale = (
+        self.scale = (
             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
             if scale_base > 0
             else None
         )
-        self.register_buffer("scale", scale)

         self._seq_len_cached = 0
         self._cos_cached = None
@@ -220,3 +218,15 @@ class RotaryEmbedding(torch.nn.Module):
             self._cos_k_cached[seqlen_offset:],
             self._sin_k_cached[seqlen_offset:],
         )
+
+    def _single_forward(self, x, indexes=0):
+        assert self.scale is None
+        self._update_cos_sin_cache(x, indexes)
+        x = x[None, ...]
+        ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
+        return ret
+
+    def _single_eval_forward(self, x, seqlen_offset=0):
+        assert self.scale is None
+        self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
+        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
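
The new _single_forward / _single_eval_forward paths apply the cached cos/sin tables to a single tensor (llama-style separate q and k) instead of a packed qkv tensor, reusing _update_cos_sin_cache and the existing legacy_apply_rotary_embed helper. As a rough reference for what that application computes, here is a self-contained sketch in plain PyTorch; it assumes the half-split (llama/GPT-NeoX) pairing convention and a (seqlen, n_heads, head_dim) layout, and it is not the repo's legacy_apply_rotary_embed implementation:

    import torch

    def rotary_reference(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
        """Apply rotary position embedding to x of shape (seqlen, n_heads, head_dim)."""
        seqlen, _, head_dim = x.shape
        # Keep the frequency table in fp32, mirroring the point of this commit.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
        freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)  # (seqlen, head_dim // 2)
        cos = freqs.cos()[:, None, :].to(x.dtype)  # broadcast over heads
        sin = freqs.sin()[:, None, :].to(x.dtype)
        x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    q = torch.randn(16, 4, 64, dtype=torch.bfloat16)
    print(rotary_reference(q).shape)  # torch.Size([16, 4, 64])

Note that both new methods assert self.scale is None, i.e. they cover only the unscaled (non-xpos) case needed for llama compatibility.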