[kernel] Add RMSLayerNorm triton kernel (#5262)

* add layerrmsnorm triton kernel

* add layerrmsnorm kernel

* modify the atol and rtol in test file

* Remove the logic of mean computation, and update the names of the kernel functions and files

* add benchmark of rms norm
pull/5282/head
Yaozheng Fang 2024-01-18 10:21:03 +08:00 committed by GitHub
parent 86b63f720c
commit 5ae9099f92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 103 additions and 62 deletions

View File

@ -10,7 +10,7 @@ except ImportError:
if HAS_TRITON:
from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_fwd
from .fused_layernorm import layer_norm
from .rms_layernorm import rms_layernorm
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
from .no_pad_rotary_embedding import rotary_embedding
@ -21,7 +21,7 @@ if HAS_TRITON:
"flash_decoding_fwd",
"copy_kv_to_blocked_cache",
"softmax",
"layer_norm",
"rms_layernorm",
"gptq_fused_linear_triton",
"rotary_embedding",
]

View File

@ -14,34 +14,28 @@ if HAS_TRITON:
# https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@triton.jit
def _layer_norm_fwd_fused(
def _rmsnorm_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
Y += row * stride
X += row * stride
# Compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.0)
x = tl.where(cols < N, x, 0.0)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
@ -50,15 +44,14 @@ if HAS_TRITON:
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - mean) * rstd
y = x_hat * w + b
x_hat = x * rstd
y = x_hat * w
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)
@torch.no_grad()
def layer_norm(x, weight, bias, eps):
def rms_layernorm(x, weight, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
@ -72,7 +65,7 @@ if HAS_TRITON:
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](
x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
_rmsnorm_kernel[(M,)](
x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y

View File

@ -1,43 +0,0 @@
import pytest
import torch
from packaging import version
from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize
# Availability flags consumed by the skipif marker on the test below.
# NOTE(review): the try body holds only `pass`, so the ImportError branch
# cannot fire as written; presumably the guarded import was stripped — confirm.
try:
    pass

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

# torch.version.cuda is None on builds without CUDA; check for that before
# parsing so the module can still be imported and the test skipped instead of
# failing at collection time.
TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
    not (TRITON_CUDA_SUPPORT and HAS_TRITON), reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
    """Compare the Triton ``layer_norm`` with ``torch.nn.functional.layer_norm``.

    Builds a float16 CUDA input of shape (M, N) plus float16 weight and bias
    vectors of length N, runs both implementations with the same epsilon, and
    checks that shape, dtype and values (atol=1e-2, rtol=0) agree.
    """
    half = torch.float16
    epsilon = 1e-5
    normalized_shape = (N,)

    # Tensors are created in the order weight, bias, input.
    weight = torch.rand(normalized_shape, dtype=half, device="cuda")
    bias = torch.rand(normalized_shape, dtype=half, device="cuda")
    x = -2.3 + 0.5 * torch.randn((M, N), dtype=half, device="cuda")

    y_triton = layer_norm(x, weight, bias, epsilon)
    y_torch = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, epsilon).to(half)

    assert y_triton.shape == y_torch.shape
    assert y_triton.dtype == y_torch.dtype

    # Report the worst element-wise deviation before asserting on it.
    max_delta = torch.max(torch.abs(y_triton - y_torch))
    print("max delta: ", max_delta)
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
# Script entry point: run the test directly when this file is executed.
# NOTE(review): the call passes no M/N; presumably the `parameterize`
# decorators supply them — confirm against colossalai.testing.utils.
if __name__ == "__main__":
    test_layer_norm()

View File

@ -0,0 +1,91 @@
import pytest
import torch
from packaging import version
import triton
from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize
from transformers.models.llama.modeling_llama import LlamaRMSNorm
# Availability flags consumed by the skipif marker on the test below.
# NOTE(review): the try body holds only `pass`, and this module imports triton
# unconditionally at the top, so the ImportError branch cannot fire as written.
try:
    pass

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

# torch.version.cuda is None on builds without CUDA; check for that before
# parsing so the module can still be imported and the test skipped instead of
# failing at collection time.
TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
    not (TRITON_CUDA_SUPPORT and HAS_TRITON), reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
    """Check the Triton ``rms_layernorm`` against ``LlamaRMSNorm``.

    Builds a float16 CUDA input of shape (M, N), runs the Triton function with
    an all-ones weight of length N and the reference module with the same
    epsilon, and requires the outputs to match within atol=1e-5 / rtol=1e-5.
    """
    half = torch.float16
    epsilon = 1e-5

    # The Triton function receives an explicit all-ones weight; the reference
    # module keeps whatever weight its constructor initialises.
    # NOTE(review): presumably also ones — confirm.
    weight = torch.ones((N,), dtype=half, device="cuda")
    reference = LlamaRMSNorm(hidden_size=N, eps=epsilon).cuda()
    x = -2.3 + 0.5 * torch.randn((M, N), dtype=half, device="cuda")

    y_triton = rms_layernorm(x, weight, eps=epsilon)
    y_llama = reference.forward(x).to(half)

    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
# Triton benchmark plot attributes: sweep SEQUENCE_TOTAL from 128 to 1024 in
# steps of 128 with HIDDEN_SIZE fixed at 1024, drawing one line per provider.
configs = [
    triton.testing.Benchmark(
        x_names=["SEQUENCE_TOTAL"],
        x_vals=list(range(128, 1025, 128)),
        line_arg="provider",
        line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
        line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name="RMSNorm benchmarking results",
        args={"HIDDEN_SIZE": 1024},
    )
]
@triton.testing.perf_report(configs)
def benchmark_rms_layernorm(
    provider: str,
    SEQUENCE_TOTAL: int,
    HIDDEN_SIZE: int,
):
    """Time one RMSNorm implementation on a float16 CUDA input.

    Args:
        provider: Either "llama_rms_layernorm" (the ``LlamaRMSNorm`` module) or
            "triton_rms_layernorm" (the Triton ``rms_layernorm`` function).
        SEQUENCE_TOTAL: Number of rows of the input tensor.
        HIDDEN_SIZE: Number of columns of the input tensor; also the length of
            the weight vector.

    Returns:
        The value returned by ``triton.testing.do_bench`` for the selected
        implementation (the benchmark config labels this axis "ms").

    Raises:
        ValueError: If ``provider`` is not one of the two names above.
    """
    half = torch.float16
    epsilon = 1e-5

    # Inputs are built before the provider is validated, in the order
    # weight, reference module, input tensor.
    weight = torch.ones((HIDDEN_SIZE,), dtype=half, device="cuda")
    reference = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=epsilon).cuda()
    x = -2.3 + 0.5 * torch.randn((SEQUENCE_TOTAL, HIDDEN_SIZE), dtype=half, device="cuda")

    def run_llama():
        return reference.forward(x).to(half)

    def run_triton():
        return rms_layernorm(x, weight, eps=epsilon)

    candidates = {
        "llama_rms_layernorm": run_llama,
        "triton_rms_layernorm": run_triton,
    }
    if provider not in candidates:
        raise ValueError("Undefined provider.")

    return triton.testing.do_bench(candidates[provider], warmup=10, rep=100)
# Script entry point: run the correctness test when this file is executed.
# NOTE(review): the call passes no M/N; presumably the `parameterize`
# decorators supply them — confirm against colossalai.testing.utils.
if __name__ == "__main__":
    test_layer_norm()
    # Uncomment to also run the benchmark with save_path=".".
    # benchmark_rms_layernorm.run(save_path=".")