from functools import reduce
from typing import Any, Tuple

import torch
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    # dtype codes (currently unused by the kernels below)
    PRECISION_MAP = {
        "fp32": (0, torch.float32),
        "fp16": (1, torch.float16),
        "bf16": (2, torch.bfloat16),
    }

    @triton.jit
    def _llama_act_combine_forward(
        X_GATE1,
        X_GATE2,
        X_UP,
        Y,
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        X_GATE1 += row * stride
        X_GATE2 += row * stride
        X_UP += row * stride
        Y += row * stride

        # Apply the activation, combine with x_up, and store the result in Y.
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
            y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
            # Write output
            tl.store(Y + cols, y, mask=mask)

    @triton.jit
    def _llama_act_combine_backward(
        X_GATE1,
        X_GATE2,
        X_UP,
        X_GATE1_GRAD,
        X_GATE2_GRAD,
        X_UP_GRAD,
        Y_GRAD,
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        X_GATE1 += row * stride
        X_GATE2 += row * stride
        X_UP += row * stride
        X_GATE1_GRAD += row * stride
        X_GATE2_GRAD += row * stride
        X_UP_GRAD += row * stride
        Y_GRAD += row * stride

        # Compute the three input gradients and store them.
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)

            # forward: y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up
            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
            x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
            x_up_grad = x_gate2_act * x_gate1
            x_gate1_grad = x_gate2_act * x_up
            # grad(x * sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
            #                      = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
            x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))

            # Write output
            tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
            tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
            tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)
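    # A minimal pure-PyTorch reference for what the fused kernels above compute.
    # This is a hypothetical helper added for illustration, not part of the
    # original module; it is handy for sanity-checking the Triton path against
    # eager PyTorch with torch.allclose / torch.autograd.gradcheck.
    def _llama_act_combine_reference(x_gate1: Tensor, x_gate2: Tensor, x_up: Tensor) -> Tensor:
        # y = x_gate1 * SiLU(x_gate2) * x_up, matching _llama_act_combine_forward
        return x_gate1 * x_gate2 * torch.sigmoid(x_gate2) * x_up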
    class LlamaActCombine(torch.autograd.Function):
        """
        act(x_gate) * x_up

        Args:
            x_gate (torch.Tensor): (b, l, 2d) x_gate
            x_up (torch.Tensor): (b, l, d) x_up
            activation (str): only "swiglu" is supported
        """

        @staticmethod
        @custom_fwd
        def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
            """
            act(x_gate) * x_up

            Args:
                x_gate (torch.Tensor): (b, l, 2d) x_gate
                x_up (torch.Tensor): (b, l, d) x_up
                activation (str): only "swiglu" is supported
            """
            assert activation == "swiglu", "Only swiglu is supported"

            # split x_gate into its two halves along the feature dim
            assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
            x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
            x_gate1 = x_gate1.contiguous()
            x_gate2 = x_gate2.contiguous()
            if not x_up.is_contiguous():
                x_up = x_up.contiguous()
            # assert shape
            assert x_gate1.shape == x_gate2.shape == x_up.shape

            # add ctx for backward
            if x_gate.requires_grad:
                ctx.save_for_backward(x_gate1, x_gate2, x_up)

            # allocate output
            y = torch.empty_like(x_up)
            M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]

            # Less than 64KB per feature: enqueue fused kernel
            MAX_FUSED_SIZE = 65536 // x_gate.element_size()
            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
            if N > BLOCK_SIZE:
                raise RuntimeError("This kernel doesn't support feature dims >= 64KB.")
            # heuristics for number of warps
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

            # save launch settings for backward
            ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps

            # enqueue kernel
            _llama_act_combine_forward[(M,)](
                x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
            )
            return y

        @staticmethod
        @custom_bwd
        def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None]:
            # restore from ctx
            (x_gate1, x_gate2, x_up) = ctx.saved_tensors
            M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps

            # init grad
            y_grad = grad_outputs[0]
            x_gate1_grad, x_gate2_grad, x_up_grad = (
                torch.empty_like(x_gate1),
                torch.empty_like(x_gate2),
                torch.empty_like(x_up),
            )

            # enqueue kernel
            _llama_act_combine_backward[(M,)](
                x_gate1,
                x_gate2,
                x_up,
                x_gate1_grad,
                x_gate2_grad,
                x_up_grad,
                y_grad,
                x_up.stride(-2),
                N,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
            x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
            # forward takes (x_gate, x_up, activation), so exactly three gradients are returned
            return x_gate_grad, x_up_grad, None
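# Minimal usage sketch (assumes a CUDA device with Triton installed; the shapes
# below are illustrative, not from the original file):
#
#   x_gate = torch.randn(2, 16, 128, device="cuda", requires_grad=True)  # (b, l, 2d)
#   x_up = torch.randn(2, 16, 64, device="cuda", requires_grad=True)     # (b, l, d)
#   y = LlamaActCombine.apply(x_gate, x_up, "swiglu")                    # (b, l, d)
#   y.sum().backward()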