mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
2.5 KiB
72 lines
2.5 KiB
1 year ago
|
import torch
|
||
|
|
||
|
try:
|
||
|
import triton
|
||
|
import triton.language as tl
|
||
1 year ago
|
|
||
1 year ago
|
HAS_TRITON = True
|
||
|
except ImportError:
|
||
|
HAS_TRITON = False
|
||
|
print("please install triton from https://github.com/openai/triton")
|
||
|
|
||
|
if HAS_TRITON:
    # CREDITS: These functions are adapted from the Triton tutorial
    # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html

    @triton.jit
    def _rmsnorm_kernel(
        X,  # pointer to the input
        Y,  # pointer to the output
        W,  # pointer to the weights
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr,
    ):
        # This triton kernel implements Root Mean Square Layer Norm (RMSNorm):
        #     y = x / sqrt(mean(x * x) + eps) * w
        # Each program instance handles one row. The same `stride` is applied
        # to both X and Y, so the caller must give both the same row pitch.

        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        Y += row * stride
        X += row * stride

        # Accumulate the sum of squares of the row in float32.
        # Out-of-range lanes are loaded as 0.0 and therefore contribute nothing.
        _sq_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
            _sq_sum += x * x
        mean_sq = tl.sum(_sq_sum, axis=0) / N
        rstd = 1 / tl.sqrt(mean_sq + eps)

        # Normalize and apply linear transformation
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            w = tl.load(W + cols, mask=mask)
            x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
            y = x * rstd * w
            # Write output. No explicit float16 cast here: the store is left
            # to convert the float32 result to the element type of Y (as in
            # the referenced tutorial), so bf16/fp32 outputs are not squeezed
            # through fp16 first.
            tl.store(Y + cols, y, mask=mask)

    @torch.no_grad()
    def rms_layernorm(x, weight, eps):
        """Apply RMSNorm over the last dimension of ``x`` with a Triton kernel.

        Args:
            x: input tensor; all leading dimensions are flattened into rows
                and the last dimension is normalized.
            weight: per-feature scale, indexed by column inside the kernel.
                NOTE(review): assumed to be a contiguous 1-D tensor with
                ``x.shape[-1]`` elements -- confirm against callers.
            eps: value added to the mean of squares before the square root.

        Returns:
            A new contiguous tensor with the same shape, dtype and device
            as ``x``.

        Raises:
            RuntimeError: if one row of ``x`` does not fit into a single
                fused block (64KB of elements).
        """
        # Reshape input data into a 2D tensor. `.contiguous()` is a no-op for
        # already-contiguous inputs; for strided views it guarantees that the
        # row stride handed to the kernel is valid for the output as well.
        x_arg = x.reshape(-1, x.shape[-1]).contiguous()
        M, N = x_arg.shape
        # Allocate the output in plain row-major layout so that it shares the
        # row pitch of x_arg (empty_like would otherwise mirror x's strides).
        y = torch.empty_like(x, memory_format=torch.contiguous_format)
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        if M == 0:
            # No rows to normalize: skip the launch of an empty grid.
            return y
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel: one program instance per row
        _rmsnorm_kernel[(M,)](
            x_arg,
            y,
            weight,
            x_arg.stride(0),
            N,
            eps,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        return y
|