mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Kernel: no pad rotary embedding (#5252)
* fix bugs * comment * use more accurate atol * fixpull/5258/head
parent
d40eb26029
commit
fded91d049
|
@ -11,6 +11,7 @@ if HAS_TRITON:
|
||||||
from .context_attn_unpad import context_attention_unpadded
|
from .context_attn_unpad import context_attention_unpadded
|
||||||
from .fused_layernorm import layer_norm
|
from .fused_layernorm import layer_norm
|
||||||
from .gptq_triton import gptq_fused_linear_triton
|
from .gptq_triton import gptq_fused_linear_triton
|
||||||
|
from .no_pad_rotary_embedding import rotary_embedding
|
||||||
from .softmax import softmax
|
from .softmax import softmax
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -18,4 +19,5 @@ if HAS_TRITON:
|
||||||
"softmax",
|
"softmax",
|
||||||
"layer_norm",
|
"layer_norm",
|
||||||
"gptq_fused_linear_triton",
|
"gptq_fused_linear_triton",
|
||||||
|
"rotary_embedding",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def rotary_embedding_kernel(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
q_token_stride,
|
||||||
|
q_head_stride,
|
||||||
|
k_token_stride,
|
||||||
|
k_head_stride,
|
||||||
|
head_dim_stride,
|
||||||
|
cos_token_stride,
|
||||||
|
cos_stride,
|
||||||
|
q_total_tokens,
|
||||||
|
Q_HEAD_NUM: tl.constexpr,
|
||||||
|
K_HEAD_NUM: tl.constexpr,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
BLOCK_HEAD: tl.constexpr,
|
||||||
|
BLOCK_TOKENS: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_head_index = tl.program_id(0)
|
||||||
|
block_token_index = tl.program_id(1)
|
||||||
|
|
||||||
|
rotary_data = q
|
||||||
|
HEAD_NUM = Q_HEAD_NUM
|
||||||
|
head_stride = q_head_stride
|
||||||
|
token_stride = q_token_stride
|
||||||
|
|
||||||
|
if block_token_index * BLOCK_TOKENS >= q_total_tokens:
|
||||||
|
block_token_index = block_token_index - tl.cdiv(q_total_tokens, BLOCK_TOKENS)
|
||||||
|
rotary_data = k
|
||||||
|
HEAD_NUM = K_HEAD_NUM
|
||||||
|
head_stride = k_head_stride
|
||||||
|
token_stride = k_token_stride
|
||||||
|
|
||||||
|
tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
|
||||||
|
head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
||||||
|
|
||||||
|
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||||
|
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||||
|
|
||||||
|
off_data0 = (
|
||||||
|
tokens_range[:, None, None] * token_stride
|
||||||
|
+ head_range[None, :, None] * head_stride
|
||||||
|
+ dim_range0[None, None, :] * head_dim_stride
|
||||||
|
)
|
||||||
|
off_data1 = (
|
||||||
|
tokens_range[:, None, None] * token_stride
|
||||||
|
+ head_range[None, :, None] * head_stride
|
||||||
|
+ dim_range1[None, None, :] * head_dim_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_data0 = tl.load(
|
||||||
|
rotary_data + off_data0,
|
||||||
|
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
loaded_data1 = tl.load(
|
||||||
|
rotary_data + off_data1,
|
||||||
|
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
|
||||||
|
|
||||||
|
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||||
|
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||||
|
|
||||||
|
out0 = loaded_data0 * loaded_cos[:, None, :] - loaded_data1 * loaded_sin[:, None, :]
|
||||||
|
out1 = loaded_data0 * loaded_sin[:, None, :] + loaded_data1 * loaded_cos[:, None, :]
|
||||||
|
|
||||||
|
# concat
|
||||||
|
tl.store(
|
||||||
|
rotary_data + off_data0,
|
||||||
|
out0,
|
||||||
|
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
rotary_data + off_data1,
|
||||||
|
out1,
|
||||||
|
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def rotary_embedding(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
q: query tensor, [total_tokens, head_num, head_dim]
|
||||||
|
k: key tensor, [total_tokens, head_num, head_dim]
|
||||||
|
cos: cosine for rotary embedding, [total_tokens, head_dim]
|
||||||
|
sin: sine for rotary embedding, [total_tokens, head_dim]
|
||||||
|
"""
|
||||||
|
q_total_tokens, q_head_num, head_dim = q.shape
|
||||||
|
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
|
||||||
|
BLOCK_HEAD = 4
|
||||||
|
BLOCK_TOKENS = 8
|
||||||
|
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
|
||||||
|
|
||||||
|
if head_dim >= 128:
|
||||||
|
num_warps = 8
|
||||||
|
else:
|
||||||
|
num_warps = 4
|
||||||
|
|
||||||
|
q_token_stride = q.stride(0)
|
||||||
|
q_head_stride = q.stride(1)
|
||||||
|
head_dim_stride = q.stride(2)
|
||||||
|
|
||||||
|
k_token_stride = k.stride(0)
|
||||||
|
k_head_stride = k.stride(1)
|
||||||
|
|
||||||
|
k_head_num = q.shape[1]
|
||||||
|
|
||||||
|
cos_token_stride = cos.stride(0)
|
||||||
|
cos_stride = cos.stride(1)
|
||||||
|
|
||||||
|
rotary_embedding_kernel[grid](
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
q_token_stride,
|
||||||
|
q_head_stride,
|
||||||
|
k_token_stride,
|
||||||
|
k_head_stride,
|
||||||
|
head_dim_stride,
|
||||||
|
cos_token_stride,
|
||||||
|
cos_stride,
|
||||||
|
q_total_tokens,
|
||||||
|
Q_HEAD_NUM=q_head_num,
|
||||||
|
K_HEAD_NUM=k_head_num,
|
||||||
|
HEAD_DIM=head_dim,
|
||||||
|
BLOCK_HEAD=BLOCK_HEAD,
|
||||||
|
BLOCK_TOKENS=BLOCK_TOKENS,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
|
@ -0,0 +1,56 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||||
|
|
||||||
|
from colossalai.kernel.triton import rotary_embedding
|
||||||
|
|
||||||
|
|
||||||
|
def torch_rotary_emb(x, cos, sin):
|
||||||
|
seq_len, h, dim = x.shape
|
||||||
|
x0 = x[:, :, 0 : dim // 2]
|
||||||
|
x1 = x[:, :, dim // 2 : dim]
|
||||||
|
cos = cos.view((seq_len, 1, dim // 2))
|
||||||
|
sin = sin.view((seq_len, 1, dim // 2))
|
||||||
|
o0 = x0 * cos - x1 * sin
|
||||||
|
o1 = x0 * sin + x1 * cos
|
||||||
|
return torch.cat((o0, o1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||||
|
@pytest.mark.parametrize("SEQ_LEN", [64])
|
||||||
|
@pytest.mark.parametrize("H", [32])
|
||||||
|
@pytest.mark.parametrize("D", [64])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||||
|
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
||||||
|
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
||||||
|
# our crafted op equals to Transformers
|
||||||
|
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
|
||||||
|
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
|
||||||
|
emb = LlamaRotaryEmbedding(D)
|
||||||
|
cos, sin = emb(x0, TOTAL_TOKENS)
|
||||||
|
cos_2 = cos[:, :32]
|
||||||
|
sin_2 = sin[:, :32]
|
||||||
|
position_ids = torch.arange(TOTAL_TOKENS)
|
||||||
|
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
|
||||||
|
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
|
||||||
|
assert torch.allclose(embd_x0, embd_stimulated_x)
|
||||||
|
|
||||||
|
# create data
|
||||||
|
q_shape = (TOTAL_TOKENS, H, D)
|
||||||
|
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||||
|
k_shape = (TOTAL_TOKENS, H, D)
|
||||||
|
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||||
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
q_ref = torch_rotary_emb(q, cos, sin)
|
||||||
|
k_ref = torch_rotary_emb(k, cos, sin)
|
||||||
|
rotary_embedding(q, k, cos, sin)
|
||||||
|
|
||||||
|
assert torch.allclose(q, q_ref, atol=1e-4, rtol=1e-4)
|
||||||
|
assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
Loading…
Reference in New Issue