mirror of https://github.com/InternLM/InternLM
[fix bug] Fix RotaryEmbedding being converted to a non-fp32 dtype during training, and add compatibility methods for the llama model. (#239)
Co-authored-by: YWMditto <862779238@qq.com>
pull/240/head
parent 54f85a6e9a
commit 28635755f5
@@ -137,15 +137,13 @@ class RotaryEmbedding(torch.nn.Module):
         """ """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
         self.scale_base = scale_base
-        scale = (
+        self.scale = (
             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
             if scale_base > 0
             else None
         )
-        self.register_buffer("scale", scale)

         self._seq_len_cached = 0
         self._cos_cached = None
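The hunk above replaces the registered buffers inv_freq and scale with plain tensor attributes. A likely reason, consistent with the commit message: nn.Module.half() / .to(dtype) casts all floating-point parameters and buffers, so the registered buffers were silently moved off fp32 under mixed-precision training, whereas plain attributes are left untouched. A minimal sketch of that dtype behaviour (illustrative only; the class names below are made up for the example and are not from the patch):

import torch


class BufferRope(torch.nn.Module):
    """Stores inv_freq the old way, as a registered buffer."""

    def __init__(self, dim=8, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)  # follows .half()/.to(dtype) casts


class AttrRope(torch.nn.Module):
    """Stores inv_freq the new way, as a plain attribute."""

    def __init__(self, dim=8, base=10000):
        super().__init__()
        # Plain attribute: .half()/.to(dtype) does not touch it, so it stays fp32.
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))


print(BufferRope().half().inv_freq.dtype)  # torch.float16 -- the buffer was downcast
print(AttrRope().half().inv_freq.dtype)    # torch.float32 -- the attribute keeps its dtype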
@@ -220,3 +218,15 @@ class RotaryEmbedding(torch.nn.Module):
             self._cos_k_cached[seqlen_offset:],
             self._sin_k_cached[seqlen_offset:],
         )
+
+    def _single_forward(self, x, indexes=0):
+        assert self.scale is None
+        self._update_cos_sin_cache(x, indexes)
+        x = x[None, ...]
+        ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
+        return ret
+
+    def _single_eval_forward(self, x, seqlen_offset=0):
+        assert self.scale is None
+        self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
+        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
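The second hunk adds _single_forward and _single_eval_forward, which apply the cached cos/sin tables to a single tensor (e.g. q or k on its own) rather than a packed qkv tensor, presumably so a llama-style attention block that keeps q and k separate can reuse the same RotaryEmbedding module; the actual rotation is delegated to legacy_apply_rotary_embed. For orientation only, here is a generic rotate-half reference showing how a cos/sin table is applied to such a tensor. This is an assumption for illustration, not the repo's legacy_apply_rotary_embed, and apply_rotary_reference is a hypothetical name:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension into two halves and swap them with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2).
    ro_dim = cos.shape[-1] * 2
    cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]  # broadcast over batch and heads
    sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    # Channel i is paired with channel i + ro_dim // 2 and the pair is rotated
    # by the cached angle; channels beyond ro_dim pass through unchanged.
    return torch.cat((x_rot * cos + rotate_half(x_rot) * sin, x_pass), dim=-1)


# Shape check: the output has the same shape as the input.
out = apply_rotary_reference(torch.randn(2, 16, 8, 64), torch.randn(16, 32), torch.randn(16, 32))
print(out.shape)  # torch.Size([2, 16, 8, 64])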