mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Update rms norm kernel, benchmark with vLLM (#5315)
* add * xi * del * del * fixpull/5280/head
parent
7ddd8b37f0
commit
1f8a75d470
|
@ -23,7 +23,6 @@ if HAS_TRITON:
|
||||||
eps, # epsilon to avoid division by zero
|
eps, # epsilon to avoid division by zero
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
|
|
||||||
# This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
|
# This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
|
||||||
|
|
||||||
# Map the program id to the row of X and Y it should compute.
|
# Map the program id to the row of X and Y it should compute.
|
||||||
|
@ -54,18 +53,19 @@ if HAS_TRITON:
|
||||||
def rms_layernorm(x, weight, eps):
|
def rms_layernorm(x, weight, eps):
|
||||||
# allocate output
|
# allocate output
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
# reshape input data into 2D tensor
|
# reshape input data into 2D tensor, (total token, hidden_size)
|
||||||
x_arg = x.reshape(-1, x.shape[-1])
|
x_arg = x.reshape(-1, x.shape[-1])
|
||||||
M, N = x_arg.shape
|
M, N = x_arg.shape
|
||||||
# Less than 64KB per feature: enqueue fused kernel
|
# Less than 64KB per feature: enqueue fused kernel
|
||||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||||
|
|
||||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||||
if N > BLOCK_SIZE:
|
if N > MAX_FUSED_SIZE:
|
||||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||||
|
|
||||||
# heuristics for number of warps
|
# heuristics for number of warps
|
||||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
|
||||||
|
|
||||||
# enqueue kernel
|
# enqueue kernel
|
||||||
_rmsnorm_kernel[(M,)](
|
_rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
|
||||||
x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
|
|
||||||
)
|
|
||||||
return y
|
return y
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
|
||||||
import triton
|
import triton
|
||||||
|
from packaging import version
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
|
||||||
from colossalai.kernel.triton import rms_layernorm
|
from colossalai.kernel.triton import rms_layernorm
|
||||||
from colossalai.testing.utils import parameterize
|
from colossalai.testing.utils import parameterize
|
||||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pass
|
pass
|
||||||
|
@ -24,7 +25,6 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||||
@parameterize("M", [2, 4, 8, 16])
|
@parameterize("M", [2, 4, 8, 16])
|
||||||
@parameterize("N", [64, 128])
|
@parameterize("N", [64, 128])
|
||||||
def test_layer_norm(M, N):
|
def test_layer_norm(M, N):
|
||||||
|
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
eps = 1e-5
|
eps = 1e-5
|
||||||
x_shape = (M, N)
|
x_shape = (M, N)
|
||||||
|
@ -39,15 +39,14 @@ def test_layer_norm(M, N):
|
||||||
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Triton benchmark plot attributions
|
# Triton benchmark plot attributions
|
||||||
configs = [
|
configs = [
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["SEQUENCE_TOTAL"],
|
x_names=["SEQUENCE_TOTAL"],
|
||||||
x_vals=[i for i in range(128, 1025, 128)],
|
x_vals=[i for i in range(128, 1025, 128)],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
|
line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
|
||||||
line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
|
line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"RMSNorm benchmarking results",
|
plot_name=f"RMSNorm benchmarking results",
|
||||||
|
@ -63,18 +62,17 @@ def benchmark_rms_layernorm(
|
||||||
HIDDEN_SIZE: int,
|
HIDDEN_SIZE: int,
|
||||||
):
|
):
|
||||||
warmup = 10
|
warmup = 10
|
||||||
rep = 100
|
rep = 1000
|
||||||
|
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
eps = 1e-5
|
eps = 1e-5
|
||||||
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
||||||
w_shape = (x_shape[-1],)
|
w_shape = (x_shape[-1],)
|
||||||
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
||||||
rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
|
vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
|
||||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||||
|
if provider == "vllm_rms_layernorm":
|
||||||
if provider == "llama_rms_layernorm":
|
fn = lambda: vllm_norm(x)
|
||||||
fn = lambda: rms_norm.forward(x).to(dtype)
|
|
||||||
elif provider == "triton_rms_layernorm":
|
elif provider == "triton_rms_layernorm":
|
||||||
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||||
else:
|
else:
|
||||||
|
@ -83,9 +81,8 @@ def benchmark_rms_layernorm(
|
||||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
|
|
||||||
return ms
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_layer_norm()
|
test_layer_norm()
|
||||||
# benchmark_rms_layernorm.run(save_path=".")
|
# benchmark_rms_layernorm.run(save_path=".", print_data=True)
|
||||||
|
|
Loading…
Reference in New Issue