@@ -1,11 +1,12 @@
 import pytest
 import torch
-from packaging import version
 import triton
+from packaging import version
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from vllm.model_executor.layers.layernorm import RMSNorm
 
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.testing.utils import parameterize
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
 try:
     pass
@@ -24,7 +25,6 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
 @parameterize("M", [2, 4, 8, 16])
 @parameterize("N", [64, 128])
 def test_layer_norm(M, N):
-
     dtype = torch.float16
     eps = 1e-5
     x_shape = (M, N)
@@ -39,15 +39,14 @@ def test_layer_norm(M, N):
     assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
 
 
-
 # Triton benchmark plot attributions
 configs = [
     triton.testing.Benchmark(
         x_names=["SEQUENCE_TOTAL"],
         x_vals=[i for i in range(128, 1025, 128)],
         line_arg="provider",
-        line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
-        line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
+        line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
+        line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
@@ -63,18 +62,17 @@ def benchmark_rms_layernorm(
     HIDDEN_SIZE: int,
 ):
     warmup = 10
-    rep = 100
+    rep = 1000
 
     dtype = torch.float16
     eps = 1e-5
     x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
     w_shape = (x_shape[-1],)
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-    rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
+    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-
-    if provider == "llama_rms_layernorm":
-        fn = lambda: rms_norm.forward(x).to(dtype)
+    if provider == "vllm_rms_layernorm":
+        fn = lambda: vllm_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
     else:
@@ -83,9 +81,8 @@ def benchmark_rms_layernorm(
     ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
 
     return ms
 
 
-
 if __name__ == "__main__":
     test_layer_norm()
-    # benchmark_rms_layernorm.run(save_path=".")
+    # benchmark_rms_layernorm.run(save_path=".", print_data=True)