ColossalAI/colossalai/kernel/triton/llama_act_combine_kernel.py

187 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from functools import reduce
from typing import Any, Tuple
import torch
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
# Lookup table: precision name -> (integer code, torch dtype).
# The integer code is the position of the entry in the list below.
PRECISION_MAP = {
    name: (code, dtype)
    for code, (name, dtype) in enumerate(
        [
            ("fp32", torch.float32),
            ("fp16", torch.float16),
            ("bf16", torch.bfloat16),
        ]
    )
}
@triton.jit
def _llama_act_combine_forward(
    X_GATE1,
    X_GATE2,
    X_UP,
    Y,
    stride,  # element offset between the starts of two consecutive rows
    N,  # number of columns in each row
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise forward pass: ``Y = X_GATE1 * X_GATE2 * sigmoid(X_GATE2) * X_UP``.

    Each program instance (indexed by ``tl.program_id(0)``) handles one row and
    walks over its ``N`` columns in chunks of ``BLOCK_SIZE``. The same ``stride``
    is applied to all four pointers.
    """
    # Advance every pointer to the start of the row owned by this program.
    row_offset = tl.program_id(0) * stride
    gate1_row = X_GATE1 + row_offset
    gate2_row = X_GATE2 + row_offset
    up_row = X_UP + row_offset
    out_row = Y + row_offset

    lanes = tl.arange(0, BLOCK_SIZE)
    for start in range(0, N, BLOCK_SIZE):
        cols = start + lanes
        in_bounds = cols < N  # guards the tail chunk when N is not a multiple of BLOCK_SIZE
        gate1 = tl.load(gate1_row + cols, mask=in_bounds, other=0.0)
        gate2 = tl.load(gate2_row + cols, mask=in_bounds, other=0.0)
        up = tl.load(up_row + cols, mask=in_bounds, other=0.0)
        # The sigmoid is evaluated in float32 and cast back to the input dtype.
        gate2_sig = tl.sigmoid(gate2.to(tl.float32)).to(gate2.dtype)
        out = gate1 * gate2 * gate2_sig * up
        tl.store(out_row + cols, out, mask=in_bounds)
@triton.jit
def _llama_act_combine_backward(
    X_GATE1,
    X_GATE2,
    X_UP,
    X_GATE1_GRAD,
    X_GATE2_GRAD,
    X_UP_GRAD,
    Y_GRAD,
    stride,  # element offset between the starts of two consecutive rows
    N,  # number of columns in each row
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise backward pass of ``y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up``.

    Each program instance (indexed by ``tl.program_id(0)``) handles one row,
    reads the three forward inputs and ``Y_GRAD``, and writes the gradients with
    respect to ``x_gate1``, ``x_gate2`` and ``x_up``. The same ``stride`` is
    applied to all seven pointers.
    """
    # Advance every pointer to the start of the row owned by this program.
    row_offset = tl.program_id(0) * stride
    gate1_row = X_GATE1 + row_offset
    gate2_row = X_GATE2 + row_offset
    up_row = X_UP + row_offset
    d_gate1_row = X_GATE1_GRAD + row_offset
    d_gate2_row = X_GATE2_GRAD + row_offset
    d_up_row = X_UP_GRAD + row_offset
    dy_row = Y_GRAD + row_offset

    lanes = tl.arange(0, BLOCK_SIZE)
    # Compute the three input gradients chunk by chunk.
    for start in range(0, N, BLOCK_SIZE):
        cols = start + lanes
        in_bounds = cols < N  # guards the tail chunk when N is not a multiple of BLOCK_SIZE
        gate1 = tl.load(gate1_row + cols, mask=in_bounds, other=0.0)
        gate2 = tl.load(gate2_row + cols, mask=in_bounds, other=0.0)
        up = tl.load(up_row + cols, mask=in_bounds, other=0.0)
        dy = tl.load(dy_row + cols, mask=in_bounds, other=0.0)

        # forward: y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up
        # The sigmoid is evaluated in float32 and cast back to the input dtype.
        gate2_sig = tl.sigmoid(gate2.to(tl.float32)).to(gate2.dtype)
        # dy * x_gate2 * sigmoid(x_gate2) is shared by the x_up and x_gate1 gradients.
        dy_act = dy * gate2 * gate2_sig
        d_up = dy_act * gate1
        d_gate1 = dy_act * up

        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        #                       = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        d_gate2 = (dy * gate1 * up) * gate2_sig * (1 + gate2 * (1 - gate2_sig))

        tl.store(d_gate1_row + cols, d_gate1, mask=in_bounds)
        tl.store(d_gate2_row + cols, d_gate2, mask=in_bounds)
        tl.store(d_up_row + cols, d_up, mask=in_bounds)
class LlamaActCombine(torch.autograd.Function):
    """Fused gate activation and combine with a hand-written backward.

    ``x_gate`` is split in half along its last axis into ``x_gate1`` and
    ``x_gate2``; the result is
    ``y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up``.

    Args:
        x_gate (torch.Tensor): gate input, e.g. (b, l, 2d); its last dimension
            must be even and each half must have the same shape as ``x_up``.
        x_up (torch.Tensor): up input, e.g. (b, l, d).
        activation (str): only "swiglu" is supported.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
        """Compute ``x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up`` with the fused kernel.

        Args:
            ctx: autograd context; receives the saved tensors and launch settings.
            x_gate (torch.Tensor): gate input, e.g. (b, l, 2d).
            x_up (torch.Tensor): up input, e.g. (b, l, d).
            activation (str): only "swiglu" is supported.

        Returns:
            torch.Tensor: tensor with the same shape and dtype as ``x_up``.

        Raises:
            AssertionError: if ``activation`` is not "swiglu", the last dimension
                of ``x_gate`` is odd, or the halves of ``x_gate`` do not match
                the shape of ``x_up``.
            RuntimeError: if one row does not fit in a single block (64KB).
        """
        assert activation == "swiglu", "Only swiglu is supported"
        # Split the gate into its two halves along the last axis.
        assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
        x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
        # The kernels address rows densely, so every operand must be contiguous.
        x_gate1 = x_gate1.contiguous()
        x_gate2 = x_gate2.contiguous()
        if not x_up.is_contiguous():
            x_up = x_up.contiguous()
        assert x_gate1.shape == x_gate2.shape == x_up.shape

        # backward() needs the inputs whenever *either* input requires a gradient;
        # checking only x_gate left saved_tensors empty when just x_up needed one.
        if x_gate.requires_grad or x_up.requires_grad:
            ctx.save_for_backward(x_gate1, x_gate2, x_up)

        y = torch.empty_like(x_up)
        # View the operands as (M, N): M rows of N features. The initial value 1
        # keeps the product defined when there are no leading dimensions.
        N = x_up.shape[-1]
        M = reduce(lambda a, b: a * b, x_up.shape[:-1], 1)
        # All buffers are contiguous, so consecutive rows are exactly N elements
        # apart. This avoids relying on stride(-2), which does not exist for 1-D
        # inputs and is not meaningful for size-1 dimensions.
        row_stride = N

        # Less than 64KB per feature row: enqueue the fused kernel.
        MAX_FUSED_SIZE = 65536 // x_gate.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("LlamaActCombine doesn't support feature dim >= 64KB.")
        # Heuristic for the number of warps: one per 256 columns, clamped to [1, 8].
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

        # Remember the launch configuration so backward() reuses it.
        ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps

        # One program per row.
        _llama_act_combine_forward[(M,)](
            x_gate1, x_gate2, x_up, y, row_stride, N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
        )
        return y

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
        """Compute the gradients with respect to ``x_gate`` and ``x_up``.

        Args:
            ctx: autograd context populated by :meth:`forward`.
            *grad_outputs (Tensor): gradient of the loss with respect to the
                forward output; only the first entry is used.

        Returns:
            Tuple of ``(x_gate_grad, x_up_grad, None, None)``, where
            ``x_gate_grad`` concatenates the gradients of the two gate halves
            along the last axis.
        """
        x_gate1, x_gate2, x_up = ctx.saved_tensors
        M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps

        # The incoming gradient may be a non-contiguous view (for example an
        # expanded tensor); the kernel indexes it with a dense row stride, so
        # materialize it first.
        y_grad = grad_outputs[0].contiguous()

        x_gate1_grad = torch.empty_like(x_gate1)
        x_gate2_grad = torch.empty_like(x_gate2)
        x_up_grad = torch.empty_like(x_up)

        # Same dense row stride as in forward(): all buffers are contiguous.
        row_stride = N
        _llama_act_combine_backward[(M,)](
            x_gate1,
            x_gate2,
            x_up,
            x_gate1_grad,
            x_gate2_grad,
            x_up_grad,
            y_grad,
            row_stride,
            N,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        # Re-assemble the gate gradient in the layout of the original x_gate.
        x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
        # NOTE(review): forward() takes three arguments but four values are
        # returned; the extra trailing None is kept for backward compatibility.
        return x_gate_grad, x_up_grad, None, None