mirror of https://github.com/InternLM/InternLM
210 lines
8.3 KiB
Python
210 lines
8.3 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from typing import Tuple
|
|
|
|
import rotary_emb
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
|
|
from torch import Tensor, nn
|
|
|
|
from internlm.core.context import ParallelMode
|
|
from internlm.core.context import global_context as gpc
|
|
|
|
from .utils import gather_forward_split_backward
|
|
|
|
|
|
class Embedding1D(nn.Module):
    """
    Embedding table whose feature dimension is partitioned across tensor-parallel ranks.

    Each rank owns a ``(num_embeddings, embedding_dim // gpc.tensor_parallel_size)`` slice of the
    table. ``forward`` looks tokens up in the local slice and hands the result to
    ``gather_forward_split_backward`` along the last dimension.

    Args:
        num_embeddings (int): The size of the vocabulary.
        embedding_dim (int): The full (un-partitioned) model dimension.
        *args: Extra positional arguments forwarded to ``F.embedding`` after ``padding_idx``.
        padding_idx (int): If specified, the entries at :attr:`padding_idx` do not contribute to the
                            gradient; therefore, the embedding vector at :attr:`padding_idx` is not
                            updated during training, i.e. it remains as a fixed "pad". None by default.
        dtype (Optional[torch.dtype]): Data type of the weight. None by default.
        **kwargs: Extra keyword arguments forwarded to ``F.embedding``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *args,
        padding_idx: int = None,
        dtype: torch.dtype = None,
        **kwargs,
    ):
        super().__init__()

        # Plain bookkeeping attributes; ``embed_dim`` keeps the full, un-partitioned width.
        self.num_embeddings = num_embeddings
        self.embed_dim = embedding_dim
        self.padding_idx = padding_idx
        self.embed_args = args
        self.embed_kwargs = kwargs

        # Only this rank's share of the feature dimension is materialised.
        # NOTE: ``torch.empty`` leaves the values uninitialised — presumably the weight is
        # initialised by the caller or loaded from a checkpoint; verify against call sites.
        local_width = embedding_dim // gpc.tensor_parallel_size
        local_shape = (num_embeddings, local_width)
        self.weight = nn.Parameter(torch.empty(local_shape, dtype=dtype))

    def forward(self, input_: Tensor) -> Tensor:
        """Look up ``input_`` in the local slice and combine the slices along the last dimension."""
        local_embeddings = F.embedding(
            input_,
            self.weight,
            self.padding_idx,
            *self.embed_args,
            **self.embed_kwargs,
        )

        # Presumably all-gathers over the tensor-parallel group in forward and splits the
        # gradient in backward, as the helper's name suggests — confirm in ``.utils``.
        return gather_forward_split_backward(local_embeddings, ParallelMode.TENSOR, dim=-1)
|
|
|
|
|
|
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """
    Autograd function that rotates the q and k slices of a packed qkv tensor in place.

    The v slice (index 2 of the second dimension) is never touched.
    """

    @staticmethod
    def _rotate_q_and_k_(packed, tables, rotary_dim, flag):
        """
        Run ``rotary_emb.apply_rotary`` in place on slot 0 (q) and then slot 1 (k) of ``packed``.

        ``tables`` is ``((cos, sin), (cos_k, sin_k))``; ``flag`` is forwarded unchanged as the last
        argument of ``apply_rotary``.
        """
        for slot, (cos_tab, sin_tab) in enumerate(tables):
            # ``chunk`` returns views, so writing to them updates ``packed`` itself.
            first_half, second_half = packed[:, slot, :, :rotary_dim].chunk(2, dim=-1)
            rotary_emb.apply_rotary(
                first_half,
                second_half,
                rearrange(cos_tab, "s d -> s 1 d"),
                rearrange(sin_tab, "s d -> s 1 d"),
                first_half,
                second_half,
                flag,
            )

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
        """
        qkv: (total, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional; fall back to cos / sin when omitted
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        Returns the same ``qkv`` tensor that was passed in.
        """
        _, three, _, headdim = qkv.shape
        assert three == 3

        rotary_seqlen, half_dim = cos.shape
        rotary_dim = 2 * half_dim
        assert rotary_dim <= headdim

        # Keys share the query tables unless dedicated ones are supplied.
        if cos_k is None:
            cos_k = cos
        if sin_k is None:
            sin_k = sin
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)

        ApplyRotaryEmbQKV_._rotate_q_and_k_(qkv, ((cos, sin), (cos_k, sin_k)), rotary_dim, False)

        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        """
        Rotate the q and k slices of the incoming gradient in place and return it.

        The same tables as in ``forward`` are used, but the last ``apply_rotary`` argument is True
        instead of False — presumably selecting the inverse (conjugate) rotation; confirm against
        the ``rotary_emb`` extension. The four table inputs receive no gradient.
        """
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        rotary_dim = 2 * cos.shape[-1]

        ApplyRotaryEmbQKV_._rotate_q_and_k_(dqkv, ((cos, sin), (cos_k, sin_k)), rotary_dim, True)

        return dqkv, None, None, None, None
|
|
|
|
|
|
# Callable form of the ApplyRotaryEmbQKV_ autograd function defined in this file
# (``torch.autograd.Function.apply``); used by RotaryEmbedding.forward.
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
|
# Callable form of flash_attn's own ApplyRotaryEmbQKV_ (imported as LegacyApplyRotaryEmbQKV_);
# used by RotaryEmbedding.eval_forward.
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
|
|
|
|
|
|
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py

    The cos/sin tables are cached on the module and rebuilt lazily whenever a longer sequence,
    a different device or a different dtype is requested.
    """

    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
        """
        Args:
            dim (int): Rotary dimension; ``dim / 2`` frequencies are generated.
            base: Base of the geometric frequency progression. 10000 by default.
            scale_base: XPos scale base; XPos scaling is enabled only when this is > 0.
            device: Device on which the ``inv_freq`` / ``scale`` buffers are created.
        """
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.scale_base = scale_base
        # ``scale`` stays None without XPos; forward paths branch on ``self.scale is None``.
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base > 0
            else None
        )
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x, indexes):
        """
        Make sure the cached tables cover every position in ``indexes`` on ``x``'s device/dtype.

        Args:
            x: Tensor whose ``device`` and ``dtype`` the cached tables must match; its shape is
                not inspected here.
            indexes: Either an int (largest position needed) or a tensor of positions; the
                tables are built with ``max(indexes) + 1`` rows.
        """
        if isinstance(indexes, int):
            seqlen = indexes + 1  # eval_forward
        else:
            seqlen = indexes.max().item() + 1

        # Reset the tables if the sequence length has grown, or if we're on a new
        # device / dtype (possibly due to tracing for instance). The explicit None check
        # keeps the device/dtype comparisons safe before the first build.
        cache_is_stale = (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        )
        if not cache_is_stale:
            return

        self._seq_len_cached = seqlen
        t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
        # Don't do einsum, it converts fp32 to fp16
        # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        freqs = torch.outer(t, self.inv_freq.to(device=t.device))

        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(x.dtype)
            self._sin_cached = torch.sin(freqs).to(x.dtype)
            return

        # Build the XPos scale on x's device (like ``freqs``) so the products below never mix
        # devices when the module buffers live somewhere else than the input.
        power = (torch.arange(seqlen, dtype=self.scale.dtype, device=x.device) - seqlen // 2) / self.scale_base
        scale = self.scale.to(device=x.device) ** power[:, None]  # (seqlen, dim / 2)
        # We want the multiplication by scale to happen in fp32
        self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
        self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
        self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
        self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def forward(self, qkv: torch.Tensor, indexes=0) -> torch.Tensor:
        """
        Apply the rotary embedding to ``qkv`` via ``apply_rotary_emb_qkv_``.

        Args:
            qkv: Packed tensor handed to ``apply_rotary_emb_qkv_`` (documented there as
                ``(total, 3, nheads, headdim)``).
            indexes: Positions used to index the cached tables, one row per token.

        Returns:
            The tensor returned by ``apply_rotary_emb_qkv_``.
        """
        self._update_cos_sin_cache(qkv, indexes)
        if self.scale is None:
            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
        return apply_rotary_emb_qkv_(
            qkv,
            self._cos_cached[indexes],
            self._sin_cached[indexes],
            self._cos_k_cached[indexes],
            self._sin_k_cached[indexes],
        )

    def eval_forward(self, qkv, seqlen_offset=0):
        """
        Apply the rotary embedding via the flash_attn (legacy) implementation.

        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.

        ``qkv.shape[1]`` is taken as the sequence length when sizing the cache; the tables are
        passed to the legacy function sliced from ``seqlen_offset`` onwards.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
        if self.scale is None:
            return legacy_apply_rotary_embed_qkv(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
            )
        return legacy_apply_rotary_embed_qkv(
            qkv,
            self._cos_cached[seqlen_offset:],
            self._sin_cached[seqlen_offset:],
            self._cos_k_cached[seqlen_offset:],
            self._sin_k_cached[seqlen_offset:],
        )
|