InternLM/internlm/model/embedding.py

210 lines
8.3 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Tuple
import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from .utils import gather_forward_split_backward
class Embedding1D(nn.Module):
    """
    Embedding table whose hidden dimension is partitioned by the tensor-parallel size.

    The local weight has shape ``(num_embeddings, embedding_dim // gpc.tensor_parallel_size)``.
    It is allocated with ``torch.empty`` and is not initialized by this class.

    Args:
        num_embeddings (int): The size of the vocabulary.
        embedding_dim (int): The dimension of the model.
        *args: Extra positional arguments passed to ``F.embedding`` after ``padding_idx``.
        padding_idx (int): If specified, the entries at :attr:`padding_idx` do not contribute to the
            gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during
            training, i.e. it remains as a fixed "pad". None by default.
        dtype (Optional[torch.dtype]): Data type of the weight. None by default.
        **kwargs: Extra keyword arguments passed to ``F.embedding``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *args,
        padding_idx: int = None,
        dtype: torch.dtype = None,
        **kwargs,
    ):
        super().__init__()

        # Width of the slice of the hidden dimension held by this rank.
        local_dim = embedding_dim // gpc.tensor_parallel_size

        self.num_embeddings = num_embeddings
        self.embed_dim = embedding_dim
        self.padding_idx = padding_idx
        # Kept verbatim so that forward() can hand them to F.embedding.
        self.embed_args = args
        self.embed_kwargs = kwargs

        local_shape = (num_embeddings, local_dim)
        self.weight = nn.Parameter(torch.empty(local_shape, dtype=dtype))

    def forward(self, input_: Tensor) -> Tensor:
        """
        Look up ``input_`` in the local weight and combine the result across the tensor-parallel group.

        The lookup result is passed to ``gather_forward_split_backward`` with
        ``ParallelMode.TENSOR`` and ``dim=-1`` (by its name, a gather along the last dimension in
        the forward pass and a split in the backward pass -- see that helper for the details).
        """
        local_embeddings = F.embedding(
            input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
        )
        return gather_forward_split_backward(local_embeddings, ParallelMode.TENSOR, dim=-1)
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """
    Autograd function that applies rotary embedding in place to the q and k slices of a packed
    qkv tensor, using the ``rotary_emb.apply_rotary`` kernel.
    """

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
        """
        qkv: (total, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional; default to cos and sin
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k, and return qkv itself.
        """
        _, qkv_slots, _, head_dim = qkv.shape
        assert qkv_slots == 3

        table_len, half_dim = cos.shape
        rot_dim = 2 * half_dim
        assert rot_dim <= head_dim

        # Keys share the query tables unless dedicated ones are supplied.
        if cos_k is None:
            cos_k = cos
        if sin_k is None:
            sin_k = sin
        assert sin.shape == cos_k.shape == sin_k.shape == (table_len, half_dim)

        # Slot 0 of the packed tensor is q, slot 1 is k; v (slot 2) is left untouched.
        for slot, (cos_tab, sin_tab) in enumerate(((cos, sin), (cos_k, sin_k))):
            # chunk() yields views, so writing the outputs into them updates qkv in place.
            first_half, second_half = qkv[:, slot, :, :rot_dim].chunk(2, dim=-1)
            rotary_emb.apply_rotary(
                first_half,
                second_half,
                rearrange(cos_tab, "s d -> s 1 d"),
                rearrange(sin_tab, "s d -> s 1 d"),
                first_half,
                second_half,
                False,
            )

        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        """
        Run the same kernel on the q and k slices of ``dqkv`` in place, with the final flag set to
        True instead of False (presumably selecting the inverse rotation -- confirm against the
        ``rotary_emb`` extension). Only ``qkv`` receives a gradient; the tables get ``None``.
        """
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        rot_dim = 2 * cos.shape[-1]

        for slot, (cos_tab, sin_tab) in enumerate(((cos, sin), (cos_k, sin_k))):
            grad_first, grad_second = dqkv[:, slot, :, :rot_dim].chunk(2, dim=-1)
            rotary_emb.apply_rotary(
                grad_first,
                grad_second,
                rearrange(cos_tab, "s d -> s 1 d"),
                rearrange(sin_tab, "s d -> s 1 d"),
                grad_first,
                grad_second,
                True,
            )

        return dqkv, None, None, None, None
# Callable form of the ApplyRotaryEmbQKV_ autograd function defined in this file;
# used by RotaryEmbedding.forward.
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
# Callable form of flash_attn's ApplyRotaryEmbQKV_ (imported as LegacyApplyRotaryEmbQKV_);
# used by RotaryEmbedding.eval_forward.
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py

    The cos/sin tables are cached on the module and rebuilt lazily by
    :meth:`_update_cos_sin_cache`.
    """

    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
        """
        Args:
            dim (int): Rotary dimension; one frequency is generated per even index in ``[0, dim)``.
            base: Base of the inverse-frequency progression. 10000 by default.
            scale_base: XPos scale base. XPos scaling is enabled only when this is > 0. 0 by default.
            device: Device on which the ``inv_freq`` and ``scale`` buffers are created. None by default.
        """
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.scale_base = scale_base
        # Per-frequency XPos scale; None disables the XPos code paths everywhere below.
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base > 0
            else None
        )
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x, indexes):
        """
        Make sure the cached tables cover every position in ``indexes`` and match ``x``.

        Args:
            x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim); only its
                device and dtype are read here.
            indexes: Either a tensor of positions (its maximum decides the table length) or an int,
                in which case the table length is ``indexes + 1``.
        """
        if not isinstance(indexes, int):
            seqlen = indexes.max().item() + 1
        else:
            seqlen = indexes + 1  # eval_forward

        # Reset the tables if they do not exist yet, if the sequence length has grown,
        # or if we're on a new device / dtype (possibly due to tracing for instance).
        # The explicit None check keeps the device/dtype comparisons from dereferencing
        # an empty cache when the length test alone does not trigger a rebuild.
        needs_rebuild = (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        )
        if not needs_rebuild:
            return

        self._seq_len_cached = seqlen
        t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
        # Don't do einsum, it converts fp32 to fp16
        freqs = torch.outer(t, self.inv_freq.to(device=t.device))

        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(x.dtype)
            self._sin_cached = torch.sin(freqs).to(x.dtype)
            return

        # Build the XPos scale on x's device, like freqs, so the products below never mix
        # devices when the module buffers live somewhere else than the input.
        power = (
            torch.arange(seqlen, dtype=self.scale.dtype, device=x.device) - seqlen // 2
        ) / self.scale_base
        scale = self.scale.to(device=x.device) ** power.unsqueeze(1)
        # We want the multiplication by scale to happen in fp32
        self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
        self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
        self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
        self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def forward(self, qkv: torch.Tensor, indexes=0) -> torch.Tensor:
        """
        Apply rotary embedding to ``qkv`` via ``apply_rotary_emb_qkv_`` and return its result.

        Args:
            qkv: Packed qkv tensor handed to ``apply_rotary_emb_qkv_``.
            indexes: Positions used to select rows of the cached cos/sin tables.
                NOTE(review): with the default int ``0`` the selected row is 1-D, while
                ``apply_rotary_emb_qkv_`` expects 2-D tables, so callers are expected to pass an
                index tensor -- confirm against the call sites.
        """
        self._update_cos_sin_cache(qkv, indexes)
        if self.scale is None:
            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
        return apply_rotary_emb_qkv_(
            qkv,
            self._cos_cached[indexes],
            self._sin_cached[indexes],
            self._cos_k_cached[indexes],
            self._sin_k_cached[indexes],
        )

    def eval_forward(self, qkv, seqlen_offset=0):
        """
        Apply rotary embedding through flash_attn's implementation, starting at ``seqlen_offset``.

        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
        if self.scale is None:
            return legacy_apply_rotary_embed_qkv(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
            )
        return legacy_apply_rotary_embed_qkv(
            qkv,
            self._cos_cached[seqlen_offset:],
            self._sin_cached[seqlen_offset:],
            self._cos_k_cached[seqlen_offset:],
            self._sin_k_cached[seqlen_offset:],
        )