import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
    generate_caches_and_block_tables_v2,
    generate_caches_and_block_tables_vllm,
)

try:
    import triton  # noqa
except ImportError:
    print("please install triton from https://github.com/openai/triton")

inference_ops = InferenceOpsLoader().load()
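# `InferenceOpsLoader().load()` above loads ColossalAI's compiled CUDA inference kernels;
# `flash_decoding_attention` imported from `colossalai.kernel.triton` is the Triton
# implementation benchmarked against them and against vLLM's paged attention below.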

# Triton benchmark plot attributes
configs = [
    triton.testing.Benchmark(
        x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
        x_vals=[2**i for i in range(3, 8)],
        line_arg="provider",
        line_vals=[
            "vllm_paged_decoding_attention",
            "triton_flash_decoding_attention",
            "cuda_flash_decoding_attention",
        ],
        line_names=[
            "vllm_paged_decoding_attention",
            "triton_flash_decoding_attention",
            "cuda_flash_decoding_attention",
        ],
        styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
        ylabel="ms",
        plot_name="FlashDecodingAttention benchmarking results",
        args={"BATCH_SIZE": 16, "BLOCK_SIZE": 32, "HEAD_SIZE": 128, "KV_GROUP_NUM": 2},
    )
]
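
# NOTE: the sweep varies MAX_NUM_BLOCKS_PER_SEQ over 8, 16, ..., 128; with the fixed
# BLOCK_SIZE of 32 this corresponds to maximum context lengths of 256 to 4096 tokens.
# KV_GROUP_NUM=2 means the 16 attention heads used below share 8 KV heads (grouped-query attention).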


def prepare_data(
    BATCH_SIZE: int,
    HEAD_SIZE: int,
    NUM_ATTN_HEADS: int,
    NUM_KV_HEADS: int,
    MAX_SEQ_LEN: int,
    dtype=torch.float16,
    device="cuda",
):
    # Use the provided maximum sequence length for each sequence when testing with the same context length,
    # otherwise generate random context lengths.
    # returns
    #   q [BATCH_SIZE, NUM_ATTN_HEADS, 1, HEAD_SIZE]
    #   k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
    kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
    num_tokens = torch.sum(kv_lengths).item()

    q_size = (BATCH_SIZE, 1, NUM_ATTN_HEADS, HEAD_SIZE)
    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
    kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
    kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)

    return q, k_unpad, v_unpad, kv_lengths
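
# Example with illustrative numbers: prepare_data(16, 128, 16, 8, 4096) returns
#   q          of shape (16, 16, 1, 128)            -- one decode query token per sequence
#   k_unpad    of shape (sum(kv_lengths), 8, 128)   -- unpadded keys for all cached tokens
#   v_unpad    of shape (sum(kv_lengths), 8, 128)   -- unpadded values, same layout as k_unpad
#   kv_lengths of shape (16,)                       -- per-sequence context lengths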


@triton.testing.perf_report(configs)
def benchmark_flash_decoding_attention(
    provider: str,
    BATCH_SIZE: int,
    BLOCK_SIZE: int,
    MAX_NUM_BLOCKS_PER_SEQ: int,
    HEAD_SIZE: int,
    KV_GROUP_NUM: int,
):
    try:
        from vllm._C import ops as vllm_ops
    except ImportError:
        raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

    warmup = 10
    rep = 1000

    dtype = torch.float16

    NUM_ATTN_HEADS = 16

    NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
    assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
    MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
    device = get_current_device()

    q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
        BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
    )

    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )

    vllm_k_cache, vllm_v_cache, _ = generate_caches_and_block_tables_vllm(
        k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
    )
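    # Both helpers scatter the unpadded K/V tokens into paged ("blocked") KV caches and build a
    # per-sequence block table mapping logical block index -> physical block id; the `_vllm`
    # variant produces the cache layout expected by vLLM's paged attention kernel, while the
    # `_v2` caches are consumed by ColossalAI's Triton and CUDA kernels.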

    block_tables = block_tables.to(device=device)
    max_seq_len_across_batch = kv_seq_lengths.max().item()
    kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
    output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
    sm_scale = 1.0 / (HEAD_SIZE**0.5)

    mid_output = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
    )
    mid_output_lse = torch.empty(
        size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
    )
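    # Flash-decoding workspace: each head's KV sequence is processed in up to `kv_max_split_num`
    # chunks of BLOCK_SIZE tokens; `mid_output` holds the partial attention output of each chunk
    # and `mid_output_lse` its log-sum-exp, which are combined into `output` at the end.
    # `sm_scale` is the usual 1/sqrt(head_dim) softmax scaling.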

    if provider == "vllm_paged_decoding_attention":
        alibi_slopes = None
        fn = lambda: vllm_ops.paged_attention_v1(
            output,
            q.squeeze(2),
            vllm_k_cache,
            vllm_v_cache,
            NUM_KV_HEADS,
            sm_scale,
            block_tables,
            kv_seq_lengths,
            BLOCK_SIZE,
            max_seq_len_across_batch,
            alibi_slopes,
            "auto",
        )
    elif provider == "triton_flash_decoding_attention":
        fn = lambda: flash_decoding_attention(
            q.squeeze(2),
            k_cache,
            v_cache,
            kv_seq_lengths,
            block_tables,
            BLOCK_SIZE,
            max_seq_len_across_batch,
            output,
            mid_output,
            mid_output_lse,
            sm_scale=sm_scale,
            kv_group_num=KV_GROUP_NUM,
        )  # [bsz, 1, num_heads, head_dim]
    elif provider == "cuda_flash_decoding_attention":
        fn = lambda: inference_ops.flash_decoding_attention(
            output,
            q.squeeze(2),
            k_cache,
            v_cache,
            kv_seq_lengths,
            block_tables,
            BLOCK_SIZE,
            max_seq_len_across_batch,
            mid_output,
            mid_output_lse,
            sm_scale,
        )
    else:
        raise ValueError("Undefined provider.")
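    # NOTE: each branch builds a zero-argument closure over the same inputs so the three kernels
    # are timed under identical conditions; the vLLM and CUDA kernels write their result into the
    # preallocated `output` tensor in place.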

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
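    # `triton.testing.do_bench` interprets `warmup` and `rep` as target warm-up and measurement
    # times in milliseconds and returns the measured latency of `fn` in ms.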

    return ms

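
# Running this file executes the sweep for every provider; with print_data=True the timing
# table is printed, and with save_path="." Triton's perf_report typically also writes the
# results (a CSV and a PNG plot) into the current directory.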
if __name__ == "__main__":
    benchmark_flash_decoding_attention.run(save_path=".", print_data=True)