mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Benchmarking rotary embedding and add a fetch function (#5277)
* fix bugs and add a cos/sin cache fetch func * add docstring * fix bug * fixpull/5306/head
parent
b7853196a0
commit
8e606ecc7e
|
@ -1,9 +1,20 @@
|
|||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||
|
||||
from colossalai.kernel.triton import rotary_embedding
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_rotary_emb(x, cos, sin):
|
||||
seq_len, h, dim = x.shape
|
||||
|
@ -52,5 +63,52 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
|||
assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
BATCH = 16
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[2**i for i in range(4, 11)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
||||
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||
args={"num_kv_heads": 16},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark_rotary_emb(
|
||||
provider: str,
|
||||
num_tokens: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
warmup = 10
|
||||
rep = 100
|
||||
|
||||
head_dim = 128
|
||||
dtype = torch.float16
|
||||
q_shape = (num_tokens, num_kv_heads, head_dim)
|
||||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||
k_shape = (num_tokens, num_kv_heads, head_dim)
|
||||
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||
cos_shape = (num_tokens, head_dim // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
|
||||
if provider == "torch_rotary_emb_func":
|
||||
fn = lambda: torch_rotary_emb(q, cos, sin)
|
||||
elif provider == "triton_rotary_emb_func":
|
||||
fn = lambda: rotary_embedding(q, k, cos, sin)
|
||||
else:
|
||||
raise ValueError("Undefined provider")
|
||||
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
||||
# benchmark_rotary_emb.run(save_path=".",print_data=True)
|
||||
|
|
Loading…
Reference in New Issue