[Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965)

* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
pull/4992/head
Cuiqing Li 2023-10-30 14:04:37 +08:00 committed by GitHub
parent cf579ff46d
commit 459a88c806
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 226 additions and 374 deletions

View File

@ -34,11 +34,13 @@ In this section we discuss how the colossal inference works and integrates with
- [x] policy
- [x] context forward
- [x] token forward
- [x] support flash-decoding
- [ ] Replace the kernels with `faster-transformer` in token-forward stage
- [ ] Support all models
- [x] Llama
- [x] Llama-2
- [x] Bloom
- [ ] Chatglm2
- [x] Chatglm2
- [ ] Benchmarking for all models
## Get started
@ -68,6 +70,12 @@ git clone https://github.com/ModelTC/lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .
# also, install xformers from source:
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```
### Docker
@ -89,7 +97,10 @@ git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .
# install xformers from source
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```
### Dive into fast-inference!

View File

@ -311,6 +311,7 @@ class TPInferEngine:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to("cuda")

View File

@ -19,6 +19,12 @@ from transformers.utils import logging
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
try:
from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_bloom_context_attention_fwd
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
def generate_alibi(n_head, dtype=torch.float16):
"""
@ -460,7 +466,10 @@ class BloomInferenceForwards:
# output = self.output[:batch_size*q_length, :, :]
output = torch.empty_like(q)
bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
if HAS_LIGHTLLM_KERNEL:
lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len)
else:
bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
else:

View File

@ -1,4 +1,6 @@
from typing import List, Optional, Tuple
import math
import copy
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
@ -10,24 +12,11 @@ from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttention
from ._utils import copy_kv_to_mem_cache
try:
from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except:
print("fall back to original rotary_embedding_neox of huggingface")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
print(
"if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
)
HAS_VLLM_KERNERL = False
try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
HAS_LIGHTLLM_KERNEL = True
@ -35,6 +24,13 @@ except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
try:
from flash_attn import flash_attn_with_kvcache
HAS_FLASH_KERNEL = True
except:
HAS_FLASH_KERNEL = False
print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention")
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@ -54,6 +50,71 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1):
if num_key_value_groups == 1:
if HAS_LIGHTLLM_KERNEL is False:
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
lightllm_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
lightllm_llama2_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
if num_key_value_groups == 1:
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
Llama2TokenAttentionForwards.token_attn(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
infer_state.other_kv_index,
)
class LlamaInferenceForwards:
"""
@ -204,7 +265,8 @@ class LlamaInferenceForwards:
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@staticmethod
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
@ -247,6 +309,7 @@ class LlamaInferenceForwards:
outputs += (present_key_value,)
return outputs
@staticmethod
def llama_flash_attn_kvcache_forward(
@ -295,27 +358,8 @@ class LlamaInferenceForwards:
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
if self.num_key_value_groups == 1:
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
)
else:
lightllm_llama2_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
)
llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@ -337,35 +381,26 @@ class LlamaInferenceForwards:
infer_state.decode_mem_index,
infer_state.cache_manager,
)
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
if self.num_key_value_groups == 1:
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
)
HAS_LIGHTLLM_KERNEL = False
if HAS_LIGHTLLM_KERNEL:
attn_output = torch.empty_like(query_states)
llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)
else:
Llama2TokenAttentionForwards.token_attn(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
infer_state.other_kv_index,
)
heads_per_group = self.num_heads // self.num_key_value_heads
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
attn_output = flash_attn_with_kvcache(q = query_states,
k_cache = copy_cache_k,
v_cache = copy_cache_v,
softmax_scale = 1/ math.sqrt(self.head_dim),
causal = True)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@ -374,22 +409,3 @@ class LlamaInferenceForwards:
# return past_key_value as None
return attn_output, None, None
def get_llama_vllm_rmsnorm_forward():
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
return _vllm_rmsnorm_forward
else:
return None

View File

@ -9,7 +9,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
from ..modeling.llama import LlamaInferenceForwards
try:
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
@ -105,9 +105,6 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
else:
# NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {"forward": partial(infer_forward)}

View File

@ -5,7 +5,6 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@ -155,39 +154,43 @@ if HAS_TRITON:
num_warps = 4 if Lk <= 64 else 8
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
alibi,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
# manually setting this blcok num, we can use tuning config to futher speed-up
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
if triton.__version__ < "2.1.0":
_context_flash_attention_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
alibi,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
# manually setting this blcok num, we can use tuning config to futher speed-up
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
return
@torch.no_grad()
@ -207,36 +210,40 @@ if HAS_TRITON:
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4
_context_flash_attention_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
None,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
if triton.__version__ < "2.1.0":
_context_flash_attention_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
None,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0")
return

View File

@ -105,8 +105,8 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
parser.add_argument("--input_len", type=int, default=256, help="Maximum input length")
parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size")
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
parser.add_argument(
"--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]

View File

@ -10,6 +10,12 @@ from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
try:
import lightllm
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
@ -52,7 +58,7 @@ def check_bloom(rank, world_size, port):
run()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

View File

@ -12,6 +12,12 @@ from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import Ch
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
try:
import lightllm
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 1
BATCH_SIZE = 8
@ -61,7 +67,7 @@ def check_chatglm2(rank, world_size, port):
run_chatglm2_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

View File

@ -12,6 +12,12 @@ from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
try:
import lightllm
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 2
BATCH_SIZE = 8
@ -57,7 +63,7 @@ def check_llama(rank, world_size, port):
run_llama_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

View File

@ -12,6 +12,12 @@ from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
try:
import lightllm
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 2
BATCH_SIZE = 8
@ -55,7 +61,7 @@ def check_llama(rank, world_size, port):
run_llama_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

View File

@ -1,60 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
import torch
from torch import nn
try:
from vllm import layernorm_ops
rms_norm = layernorm_ops.rms_norm
HAS_VLLM_KERNERL = True
except:
print("please install vllm kernels to install rmsnorm")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
HAS_VLLM_KERNERL = False
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out
@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rmsnorm():
data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
hg_rms = LlamaRMSNorm(64)
hg_rms = hg_rms.half().cuda()
out_torch = hg_rms(data)
out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)
check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward"
if __name__ == "__main__":
test_rmsnorm()

View File

@ -1,153 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Tuple
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half
try:
from vllm import pos_encoding_ops
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except:
print("fall back to original rotary_embedding_neox of huggingface")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
HAS_VLLM_KERNERL = False
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class RefRotaryEmbeddingNeox(nn.Module):
"""Reference implementation of the GPT-NeoX style rotary embedding."""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: int = 10000,
) -> None:
super().__init__()
self.rotary_dim = dim
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=inv_freq.dtype)
sin = emb.sin().to(dtype=inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query_rot = query_rot.transpose(0, 1).contiguous()
key_rot = key_rot.transpose(0, 1).contiguous()
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
# Output query/key shape: [num_tokens, num_tokens, head_size]
return query, key
def run_rotary_embedding_neox(
num_tokens: int,
num_heads: int,
head_size: int,
max_position: int,
rotary_dim: int,
dtype: torch.dtype,
base: int = 10000,
) -> None:
positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda")
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda")
# Create the rotary embedding.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
rotary_embedding_neox(
positions,
out_query,
out_key,
head_size,
cos_sin_cache,
)
# Run the reference implementation.
ref_rotary_embedding = RefRotaryEmbeddingNeox(
dim=rotary_dim,
max_position_embeddings=max_position,
base=base,
).to(dtype=dtype, device="cuda")
ref_query, ref_key = ref_rotary_embedding(
positions,
query.view(num_tokens, num_heads, head_size),
key.view(num_tokens, num_heads, head_size),
)
ref_query = ref_query.view(num_tokens, num_heads * head_size)
ref_key = ref_key.view(num_tokens, num_heads * head_size)
# Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rotary_embedding():
run_rotary_embedding_neox(
num_tokens=1024,
num_heads=8,
head_size=64,
max_position=8192,
rotary_dim=64,
dtype=torch.float16,
)
if __name__ == "__main__":
test_rotary_embedding()