mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
4.0 KiB
111 lines
4.0 KiB
import torch
|
|
|
|
from colossalai.kernel.triton import flash_decoding_attention
|
|
from colossalai.utils import get_current_device
|
|
from tests.test_infer.test_ops.triton.kernel_utils import (
|
|
convert_kv_unpad_to_padded,
|
|
generate_caches_and_block_tables_v2,
|
|
prepare_padding_mask,
|
|
torch_attn_ref,
|
|
)
|
|
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
|
|
|
|
try:
    import triton  # noqa
except ImportError as err:
    # The rest of this script needs `triton` at import time (triton.testing.Benchmark,
    # triton.testing.perf_report, triton.cdiv). The old code only printed a hint and
    # continued, which then failed with a confusing NameError. Fail fast with the same
    # hint instead, chaining the original import error for context.
    raise ImportError("please install triton from https://github.com/openai/triton") from err
|
|
|
|
# Decoding benchmark settings: one query token per sequence (the q_len dimension is 1).
Q_LEN = 1
HEAD_DIM = 128
BATCH = 16
BLOCK_SIZE = 32  # tokens per KV-cache block
SAME_LEN = True  # forwarded to prepare_data as `same_context_len`
# Forwarded to triton.testing.do_bench as `warmup` / `rep`.
# NOTE(review): Triton documents these as time budgets in milliseconds, not iteration counts — confirm.
WARM_UPS = 10
REPS = 100

# A single sweep over KV_LEN; each point is timed for both the PyTorch reference and the Triton kernel.
# (A denser alternative sweep that was used before: list(range(256, 8192, 256)).)
configs = [
    triton.testing.Benchmark(
        plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
        ylabel="ms",
        x_names=["KV_LEN"],
        x_vals=[1 << exp for exp in range(8, 14)],  # 256, 512, ..., 8192
        line_arg="provider",
        line_vals=["torch", "triton"],
        line_names=["Torch", "Triton"],
        styles=[("red", "-"), ("blue", "-")],
        args=dict(bsz=BATCH, block_size=BLOCK_SIZE, same_context_len=SAME_LEN, kv_group_num=1),
    )
]
|
|
|
|
|
|
@triton.testing.perf_report(configs)
def bench_kernel(
    bsz: int,
    KV_LEN: int,
    provider: str,
    block_size: int,
    kv_group_num: int,
    same_context_len: bool,
):
    """Time one decoding-attention configuration for a single provider.

    Called by ``triton.testing.perf_report`` for each (``KV_LEN``, ``provider``) point
    in ``configs``. ``bsz``, ``block_size``, ``same_context_len`` and ``kv_group_num``
    come from the config's ``args`` dict.

    Args:
        bsz: Batch size (number of sequences).
        KV_LEN: Requested KV length. It is rounded up to a whole number of cache blocks.
        provider: ``"torch"`` benchmarks the padded PyTorch reference (``torch_attn_ref``).
            ``"triton"`` benchmarks ``flash_decoding_attention`` over a blocked KV cache.
        block_size: Number of tokens per KV-cache block.
        kv_group_num: Number of query heads per KV head. It must evenly divide the
            number of attention heads.
        same_context_len: Forwarded to ``prepare_data``. Presumably, when False, the
            per-sequence KV lengths vary (hence ``max_seq_len_in_b``) — verify.

    Returns:
        ``(median_ms, p20_ms, p80_ms)`` from ``triton.testing.do_bench`` with
        ``quantiles=[0.5, 0.2, 0.8]``.

    Raises:
        ValueError: If ``kv_group_num`` is not a positive divisor of the number of
            attention heads, or if ``provider`` is not recognized.
    """
    num_attn_heads = 16
    # Validate before dividing. The old post-hoc check could not detect a group size
    # that does not divide the head count, and kv_group_num == 0 raised ZeroDivisionError.
    if kv_group_num <= 0 or num_attn_heads % kv_group_num != 0:
        raise ValueError("Invalid number of kv heads.")
    num_kv_heads = num_attn_heads // kv_group_num

    # Round the requested KV length up to a whole number of cache blocks.
    max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
    max_seq_len = block_size * max_num_blocks_per_seq

    dtype = torch.float16
    device = get_current_device()

    q, k_unpad, v_unpad, kv_lengths = prepare_data(
        bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
    )
    # Longest actual KV length in the batch; may be shorter than max_seq_len for random lengths.
    max_seq_len_in_b = kv_lengths.max().item()

    # do_bench returns these quantiles in this order: median, 20th and 80th percentile.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        # Reference path: pad the unpadded K/V to the batch's longest sequence and mask the padding.
        k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
        v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
        torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)

        def fn():
            return torch_attn_ref(
                q, k_torch, v_torch, torch_padding_mask, bsz, Q_LEN, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
            )

    elif provider == "triton":
        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
        )
        block_tables = block_tables.to(device=device)
        # Each KV split is at most one cache block long, so the number of splits is
        # the block count of the longest sequence in the batch.
        kv_max_split_num = triton.cdiv(max_seq_len_in_b, block_size)
        # Preallocated buffers handed to the kernel: the final output plus float32
        # per-split intermediates (`_lse` presumably = log-sum-exp — verify).
        output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
        mid_output = torch.empty(
            size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
        )
        mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
        sm_scale = 1.0 / (HEAD_DIM**0.5)

        def fn():
            # q.squeeze(2) hides the q_len dimension (equal to 1 when decoding), matching
            # the attention forward in modeling. This path therefore assumes Q_LEN == 1.
            # `output` has shape (bsz, num_attn_heads, HEAD_DIM).
            return flash_decoding_attention(
                q.squeeze(2),
                k_cache,
                v_cache,
                kv_lengths,
                block_tables,
                block_size,
                max_seq_len_in_b,
                output,
                mid_output,
                mid_output_lse,
                sm_scale=sm_scale,
                kv_group_num=kv_group_num,
            )

    else:
        # Previously an unknown provider fell through to an UnboundLocalError on `ms`.
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
    return ms, min_ms, max_ms
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every configured sweep, print the timing table, and save the results under the current directory.
    bench_kernel.run(print_data=True, save_path=".")
|