[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)

* fix decoding kernel pytest

* revise and add triton context attn benchmark
pull/5306/head
Yuanheng Zhao 2024-01-23 17:16:02 +08:00 committed by GitHub
parent 8e606ecc7e
commit 3da9993b0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 116 additions and 15 deletions

View File

@ -87,7 +87,7 @@ class PagedAttention:
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
"""
bsz = len(seq_lengths)
padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)
padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size, dtype=tensor.dtype)
token_idx = 0
for i, seq_len in enumerate(seq_lengths):

View File

@ -5,6 +5,8 @@
#
# Inspired and modified from Triton Tutorial - Fused Attention
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
from typing import Optional
import torch
import triton
import triton.language as tl
@ -190,13 +192,8 @@ def context_attention_unpadded(
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence],
block_size: int,
max_seq_len_in_b: Optional[int] = None,
):
# q/k in context stage are supposed to be put into k_cache and v_cache.
# This step can be optimized in future.
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk == Lv
assert Lk in {32, 64, 128, 256}
@ -210,7 +207,7 @@ def context_attention_unpadded(
num_kv_group = num_heads // num_kv_heads
num_seqs, max_blocks_per_seq = block_tables.shape
max_seq_len = context_lengths.max().item()
max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
sm_scale = 1.0 / (Lq**0.5)
output = torch.zeros_like(q)
@ -220,7 +217,7 @@ def context_attention_unpadded(
assert block_size in {16, 32, 64, 128}
BLOCK_M = BLOCK_N = block_size
grid = (num_seqs, num_heads, triton.cdiv(max_seq_len, BLOCK_M))
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
_fwd_context_paged_attention_kernel[grid](
q,

View File

@ -215,10 +215,9 @@ def flash_decoding_attention(
Returns:
Output tensor with shape [bsz, num_heads, q_len, head_dim]
"""
if q.dim() == 3:
bsz, num_heads, head_dim = q.shape
else:
raise ValueError(f"The query dim should be 3, but got {q.dim()}.")
q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
bsz, num_heads, head_dim = q.shape
assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (

View File

@ -1,7 +1,9 @@
import pytest
import torch
from packaging import version
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref
@ -89,6 +91,7 @@ def test_context_attention(
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
q_unpad = q_unpad.contiguous()
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
@ -109,5 +112,103 @@ def test_context_attention(
assert torch.equal(v_cache_ref, v_cache_triton)
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
configs = [
triton.testing.Benchmark(
x_names=["KV_LEN"],
x_vals=[2**i for i in range(8, 13)],
# x_vals=[x for x in range(256, 8192, 256)],
line_arg="provider",
line_vals=["torch", "triton"],
line_names=["Torch", "Triton"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}",
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
)
]
@triton.testing.perf_report(configs)
def bench_kernel(
bsz,
KV_LEN,
provider,
block_size: int,
kv_group_num: int,
same_context_len: bool,
):
num_attn_heads = 16
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
max_seq_len = block_size * max_num_blocks_per_seq
num_kv_heads = num_attn_heads // kv_group_num
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
block_size * max_num_blocks_per_seq
dtype = torch.float16
device = get_current_device()
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
num_tokens = torch.sum(context_lengths).item()
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
q_unpad = q_unpad.contiguous()
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
q_padded = PagedAttention.pad_and_reshape(q_unpad, context_lengths, max_seq_len, num_attn_heads, HEAD_DIM)
k_padded = PagedAttention.pad_and_reshape(k_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
v_padded = PagedAttention.pad_and_reshape(v_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
q_padded, k_padded, v_padded = (
q_padded.to(device=device),
k_padded.to(device=device),
v_padded.to(device=device),
)
q_padded = q_padded.transpose(1, 2)
k_padded = PagedAttention.repeat_kv(k_padded.transpose(1, 2), kv_group_num)
v_padded = PagedAttention.repeat_kv(v_padded.transpose(1, 2), kv_group_num)
# This benchmark ignores the padding mask. *Only* use the-same-length inputs for benchmarkings
attn_mask = AttentionMaskConverter._make_causal_mask(
(bsz, max_seq_len), q_padded.dtype, q_padded.device, past_key_values_length=0
)
attn_mask = attn_mask.to(device=q_padded.device)
fn = lambda: torch_attn_ref(
q_padded,
k_padded,
v_padded,
attn_mask,
bsz,
max_seq_len,
max_seq_len,
num_attn_heads,
num_kv_heads,
HEAD_DIM,
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
if provider == "triton":
k_cache_triton = torch.zeros_like(k_cache_ref)
v_cache_triton = torch.zeros_like(v_cache_ref)
fn = lambda: context_attention_unpadded(
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
return ms, min_ms, max_ms
if __name__ == "__main__":
test_context_attention(4, 32, 8, 16, 1, True)
# bench_kernel.run(save_path=".", print_data=True)

View File

@ -97,7 +97,9 @@ def test_flash_decoding(
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
sm_scale = 1.0 / (HEAD_DIM**0.5)
out_triton = flash_decoding_attention(
q,
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
# refer to attention forward in modeling.
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
@ -188,7 +190,9 @@ def bench_kernel(
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
sm_scale = 1.0 / (HEAD_DIM**0.5)
fn = lambda: flash_decoding_attention(
q,
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
# refer to attention forward in modeling.
q.squeeze(2),
k_cache,
v_cache,
kv_lengths,