diff --git a/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py
index eeb125776..d611234f0 100644
--- a/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py
@@ -1,9 +1,20 @@
 import pytest
 import torch
+from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.triton import rotary_embedding
 
+try:
+    import triton  # noqa
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+
 
 def torch_rotary_emb(x, cos, sin):
     seq_len, h, dim = x.shape
@@ -52,5 +63,52 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
 
 
+BATCH = 16
+configs = [
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[2**i for i in range(4, 11)],
+        line_arg="provider",
+        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        styles=[("red", "-"), ("blue", "-")],
+        ylabel="ms",
+        plot_name=f"rotary_emb-batch-{BATCH}",
+        args={"num_kv_heads": 16},
+    )
+]
+
+
+@triton.testing.perf_report(configs)
+def benchmark_rotary_emb(
+    provider: str,
+    num_tokens: int,
+    num_kv_heads: int,
+):
+    warmup = 10
+    rep = 100
+
+    head_dim = 128
+    dtype = torch.float16
+    q_shape = (num_tokens, num_kv_heads, head_dim)
+    q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
+    k_shape = (num_tokens, num_kv_heads, head_dim)
+    k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
+    cos_shape = (num_tokens, head_dim // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+
+    if provider == "torch_rotary_emb_func":
+        fn = lambda: torch_rotary_emb(q, cos, sin)
+    elif provider == "triton_rotary_emb_func":
+        fn = lambda: rotary_embedding(q, k, cos, sin)
+    else:
+        raise ValueError("Undefined provider")
+
+    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    return ms
+
+
 if __name__ == "__main__":
     test_rotary_emb(4, 64, 32, 64, torch.float32)
+    # benchmark_rotary_emb.run(save_path=".", print_data=True)
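
For reference, a minimal sketch of how the new benchmark could be driven once this diff is applied. This assumes a CUDA device and a working triton install; the import path below is hypothetical (it depends on how the tests directory is resolved on PYTHONPATH), and the .run() arguments mirror the commented-out call at the bottom of the file:

    # Hypothetical driver script, not part of the diff above.
    # Runs the correctness test first, then sweeps num_tokens over the
    # configured x_vals (16 .. 1024) and reports latency in ms per provider.
    import torch

    from tests.test_infer_ops.triton.test_rotary_embdding_unpad import (  # assumed import path
        benchmark_rotary_emb,
        test_rotary_emb,
    )

    test_rotary_emb(4, 64, 32, 64, torch.float32)
    benchmark_rotary_emb.run(save_path=".", print_data=True)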