#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Tuple

import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc

from .utils import gather_forward_split_backward


class Embedding1D(nn.Module):
    """
    1D Embedding.

    Args:
        num_embeddings (int): The size of the vocabulary.
        embedding_dim (int): The dimension of the model.
        padding_idx (int): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                           therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                           i.e. it remains as a fixed "pad". None by default.
        dtype (Optional[torch.dtype]): Data type. None by default.

    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *args,
        padding_idx: int = None,
        dtype: torch.dtype = None,
        **kwargs,
    ):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embed_dim = embedding_dim
        embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size

        self.padding_idx = padding_idx
        self.embed_args = args
        self.embed_kwargs = kwargs

        # Each tensor-parallel rank holds a slice of the embedding dimension.
        self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))

    def forward(self, input_: Tensor) -> Tensor:
        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

        # Gather the partial embeddings along the hidden dimension; the backward pass splits the gradient again.
        output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

        return output


class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """
    ApplyRotaryEmbQKV_
    """

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
        """
        qkv: (total, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        _, three, _, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
        k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(
            k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
        )
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(
            dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
        )
        dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(
            dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
        )
        return dqkv, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are transformed by rotation matrices
    which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in GPT-NeoX_;
    GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
        super().__init__()
        # Generate and save the inverse frequency buffer (non-trainable).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base > 0
            else None
        )
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x, indexes):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
        if not isinstance(indexes, int):
            seqlen = indexes.max().item() + 1
        else:
            seqlen = indexes + 1  # eval_forward
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
        self._update_cos_sin_cache(qkv, indexes)
        if self.scale is None:
            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
        else:
            return apply_rotary_emb_qkv_(
                qkv,
                self._cos_cached[indexes],
                self._sin_cached[indexes],
                self._cos_k_cached[indexes],
                self._sin_k_cached[indexes],
            )

    def eval_forward(self, qkv, seqlen_offset=0):
        """
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
        if self.scale is None:
            return legacy_apply_rotary_embed_qkv(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
            )
        else:
            return legacy_apply_rotary_embed_qkv(
                qkv,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self._cos_k_cached[seqlen_offset:],
                self._sin_k_cached[seqlen_offset:],
            )
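

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API): a minimal,
# hypothetical example of applying RotaryEmbedding to a packed qkv tensor in
# the (total, 3, nheads, headdim) layout expected by ApplyRotaryEmbQKV_.
# It assumes a CUDA device and the flash-attn `rotary_emb` extension are
# available; the sequence lengths, head count, and head dimension below are
# arbitrary. Embedding1D is omitted here because it requires an initialized
# tensor-parallel context (gpc).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two packed sequences of lengths 5 and 11 -> 16 tokens in total.
    lengths = [5, 11]
    total, nheads, headdim = sum(lengths), 8, 64

    rotary = RotaryEmbedding(dim=headdim, device="cuda")

    # Packed qkv: (total_tokens, 3, nheads, headdim).
    qkv = torch.randn(total, 3, nheads, headdim, device="cuda", dtype=torch.float16)

    # Per-token position indexes, restarting at 0 for every packed sequence;
    # they select the rows of the cached cos/sin tables.
    indexes = torch.cat([torch.arange(length, device="cuda") for length in lengths])

    qkv = rotary(qkv, indexes=indexes)
    print(qkv.shape)  # torch.Size([16, 3, 8, 64])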