diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index 8c59aaf..d4ae9b5 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -137,15 +137,14 @@ class RotaryEmbedding(torch.nn.Module):
         """ """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        # Plain tensor attributes (not registered buffers): excluded from state_dict and from Module.to()/.half().
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
         self.scale_base = scale_base
-        scale = (
+        self.scale = (
             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
             if scale_base > 0
             else None
         )
-        self.register_buffer("scale", scale)
 
         self._seq_len_cached = 0
         self._cos_cached = None
@@ -220,3 +219,41 @@ class RotaryEmbedding(torch.nn.Module):
                 self._cos_k_cached[seqlen_offset:],
                 self._sin_k_cached[seqlen_offset:],
             )
+
+    def _single_forward(self, x, indexes=0):
+        """Apply the rotary embedding to a single tensor at the positions given by ``indexes``.
+
+        Only the unscaled path is supported: ``self.scale`` must be None.
+
+        Args:
+            x: tensor to rotate. A leading dimension of size 1 is added before the call to
+                ``legacy_apply_rotary_embed`` and removed from the result with ``squeeze(0)``.
+                NOTE(review): presumably a packed (total_tokens, nheads, headdim) layout -- confirm.
+            indexes: position index(es); passed to ``_update_cos_sin_cache`` and used to select
+                rows of the cached cos/sin tables.
+
+        Returns:
+            The rotated tensor, without the temporary leading dimension.
+        """
+        assert self.scale is None, "_single_forward requires scale to be None"
+        self._update_cos_sin_cache(x, indexes)
+        x = x[None, ...]
+        ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
+        return ret
+
+    def _single_eval_forward(self, x, seqlen_offset=0):
+        """Apply the rotary embedding to a single tensor, starting at position ``seqlen_offset``.
+
+        Only the unscaled path is supported: ``self.scale`` must be None.
+
+        Args:
+            x: tensor to rotate; ``x.shape[1]`` is taken as the number of positions it covers.
+                NOTE(review): presumably (batch, seqlen, nheads, headdim) -- confirm.
+            seqlen_offset: number of leading rows of the cached cos/sin tables to skip.
+
+        Returns:
+            The output of ``legacy_apply_rotary_embed`` for ``x`` with the offset cos/sin tables.
+        """
+        assert self.scale is None, "_single_eval_forward requires scale to be None"
+        self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
+        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])