mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Kernel: no pad rotary embedding (#5252)
* fix bugs * comment * use more accurate atol * fixpull/5258/head
parent
d40eb26029
commit
fded91d049
|
@ -11,6 +11,7 @@ if HAS_TRITON:
|
|||
from .context_attn_unpad import context_attention_unpadded
|
||||
from .fused_layernorm import layer_norm
|
||||
from .gptq_triton import gptq_fused_linear_triton
|
||||
from .no_pad_rotary_embedding import rotary_embedding
|
||||
from .softmax import softmax
|
||||
|
||||
__all__ = [
|
||||
|
@ -18,4 +19,5 @@ if HAS_TRITON:
|
|||
"softmax",
|
||||
"layer_norm",
|
||||
"gptq_fused_linear_triton",
|
||||
"rotary_embedding",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def rotary_embedding_kernel(
|
||||
q,
|
||||
k,
|
||||
cos,
|
||||
sin,
|
||||
q_token_stride,
|
||||
q_head_stride,
|
||||
k_token_stride,
|
||||
k_head_stride,
|
||||
head_dim_stride,
|
||||
cos_token_stride,
|
||||
cos_stride,
|
||||
q_total_tokens,
|
||||
Q_HEAD_NUM: tl.constexpr,
|
||||
K_HEAD_NUM: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_HEAD: tl.constexpr,
|
||||
BLOCK_TOKENS: tl.constexpr,
|
||||
):
|
||||
block_head_index = tl.program_id(0)
|
||||
block_token_index = tl.program_id(1)
|
||||
|
||||
rotary_data = q
|
||||
HEAD_NUM = Q_HEAD_NUM
|
||||
head_stride = q_head_stride
|
||||
token_stride = q_token_stride
|
||||
|
||||
if block_token_index * BLOCK_TOKENS >= q_total_tokens:
|
||||
block_token_index = block_token_index - tl.cdiv(q_total_tokens, BLOCK_TOKENS)
|
||||
rotary_data = k
|
||||
HEAD_NUM = K_HEAD_NUM
|
||||
head_stride = k_head_stride
|
||||
token_stride = k_token_stride
|
||||
|
||||
tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
|
||||
head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
||||
|
||||
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||
|
||||
off_data0 = (
|
||||
tokens_range[:, None, None] * token_stride
|
||||
+ head_range[None, :, None] * head_stride
|
||||
+ dim_range0[None, None, :] * head_dim_stride
|
||||
)
|
||||
off_data1 = (
|
||||
tokens_range[:, None, None] * token_stride
|
||||
+ head_range[None, :, None] * head_stride
|
||||
+ dim_range1[None, None, :] * head_dim_stride
|
||||
)
|
||||
|
||||
loaded_data0 = tl.load(
|
||||
rotary_data + off_data0,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
other=0.0,
|
||||
)
|
||||
loaded_data1 = tl.load(
|
||||
rotary_data + off_data1,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
|
||||
|
||||
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||
|
||||
out0 = loaded_data0 * loaded_cos[:, None, :] - loaded_data1 * loaded_sin[:, None, :]
|
||||
out1 = loaded_data0 * loaded_sin[:, None, :] + loaded_data1 * loaded_cos[:, None, :]
|
||||
|
||||
# concat
|
||||
tl.store(
|
||||
rotary_data + off_data0,
|
||||
out0,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
)
|
||||
tl.store(
|
||||
rotary_data + off_data1,
|
||||
out1,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def rotary_embedding(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
q: query tensor, [total_tokens, head_num, head_dim]
|
||||
k: key tensor, [total_tokens, head_num, head_dim]
|
||||
cos: cosine for rotary embedding, [total_tokens, head_dim]
|
||||
sin: sine for rotary embedding, [total_tokens, head_dim]
|
||||
"""
|
||||
q_total_tokens, q_head_num, head_dim = q.shape
|
||||
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
|
||||
BLOCK_HEAD = 4
|
||||
BLOCK_TOKENS = 8
|
||||
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
|
||||
|
||||
if head_dim >= 128:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
q_token_stride = q.stride(0)
|
||||
q_head_stride = q.stride(1)
|
||||
head_dim_stride = q.stride(2)
|
||||
|
||||
k_token_stride = k.stride(0)
|
||||
k_head_stride = k.stride(1)
|
||||
|
||||
k_head_num = q.shape[1]
|
||||
|
||||
cos_token_stride = cos.stride(0)
|
||||
cos_stride = cos.stride(1)
|
||||
|
||||
rotary_embedding_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
cos,
|
||||
sin,
|
||||
q_token_stride,
|
||||
q_head_stride,
|
||||
k_token_stride,
|
||||
k_head_stride,
|
||||
head_dim_stride,
|
||||
cos_token_stride,
|
||||
cos_stride,
|
||||
q_total_tokens,
|
||||
Q_HEAD_NUM=q_head_num,
|
||||
K_HEAD_NUM=k_head_num,
|
||||
HEAD_DIM=head_dim,
|
||||
BLOCK_HEAD=BLOCK_HEAD,
|
||||
BLOCK_TOKENS=BLOCK_TOKENS,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
return
|
|
@ -0,0 +1,56 @@
|
|||
import pytest
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||
|
||||
from colossalai.kernel.triton import rotary_embedding
|
||||
|
||||
|
||||
def torch_rotary_emb(x, cos, sin):
|
||||
seq_len, h, dim = x.shape
|
||||
x0 = x[:, :, 0 : dim // 2]
|
||||
x1 = x[:, :, dim // 2 : dim]
|
||||
cos = cos.view((seq_len, 1, dim // 2))
|
||||
sin = sin.view((seq_len, 1, dim // 2))
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
return torch.cat((o0, o1), dim=-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||
@pytest.mark.parametrize("SEQ_LEN", [64])
|
||||
@pytest.mark.parametrize("H", [32])
|
||||
@pytest.mark.parametrize("D", [64])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
||||
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
||||
# our crafted op equals to Transformers
|
||||
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
|
||||
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
|
||||
emb = LlamaRotaryEmbedding(D)
|
||||
cos, sin = emb(x0, TOTAL_TOKENS)
|
||||
cos_2 = cos[:, :32]
|
||||
sin_2 = sin[:, :32]
|
||||
position_ids = torch.arange(TOTAL_TOKENS)
|
||||
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
|
||||
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
|
||||
assert torch.allclose(embd_x0, embd_stimulated_x)
|
||||
|
||||
# create data
|
||||
q_shape = (TOTAL_TOKENS, H, D)
|
||||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||
k_shape = (TOTAL_TOKENS, H, D)
|
||||
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
|
||||
q_ref = torch_rotary_emb(q, cos, sin)
|
||||
k_ref = torch_rotary_emb(k, cos, sin)
|
||||
rotary_embedding(q, k, cos, sin)
|
||||
|
||||
assert torch.allclose(q, q_ref, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
Loading…
Reference in New Issue