[kernel] skip tests of flash_attn and triton when they are not available (#1798)

Jiarui Fang 2022-11-07 13:41:13 +08:00 committed by GitHub
parent e34e850a4c
commit c248800359
3 changed files with 412 additions and 301 deletions

View File

@@ -61,7 +61,7 @@ class GeminiManager:
        self._comp_cuda_demand_time = 0

    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
-        """ Adjust the layout of statefuil tensor according to the information provided
+        """ Adjust the layout of stateful tensors according to the information provided
        by mem_stats_collector, which should belongs to a Sharded Model.
        """
        # find stateful tensor in state COMPUTE

View File

@@ -5,20 +5,24 @@ This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""
-import torch
-import subprocess
import os
+import subprocess
+
+import torch

try:
    import triton
    import triton.language as tl
+    HAS_TRITON = True
except ImportError:
-    raise ImportError('please install triton from https://github.com/openai/triton')
+    print('please install triton from https://github.com/openai/triton')
+    HAS_TRITON = False

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    HAS_FLASH_ATTN = True
except ImportError:
-    raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
+    HAS_FLASH_ATTN = False
+    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')


def triton_check():
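With the hard `raise ImportError` replaced by `HAS_TRITON` / `HAS_FLASH_ATTN` flags, importing the module no longer fails when an optional dependency is missing; callers can branch on the flags instead. A minimal caller-side sketch, assuming only the flag names and module path shown in this diff (the helper function itself is hypothetical):

# Hypothetical helper: choose an attention backend from the availability flags
# exported by colossalai.kernel.cuda_native.flash_attention after this commit.
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON


def pick_attention_backend() -> str:
    # Prefer the flash-attn CUDA kernels, fall back to the Triton kernel,
    # and report "none" when neither optional dependency is installed.
    if HAS_FLASH_ATTN:
        return "flash_attn"
    if HAS_TRITON:
        return "triton"
    return "none"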
@@ -33,299 +37,396 @@ def triton_check():
        return True
    return False


TRITON_AVALIABLE = triton_check()
if TRITON_AVALIABLE:

    @triton.jit
    def _fwd_kernel(
        Q,
        K,
        V,
        sm_scale,
        TMP,
        L,
        M,    # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
        Out,
        stride_qz,
        stride_qh,
        stride_qm,
        stride_qk,
        stride_kz,
        stride_kh,
        stride_kn,
        stride_kk,
        stride_vz,
        stride_vh,
        stride_vk,
        stride_vn,
        stride_oz,
        stride_oh,
        stride_om,
        stride_on,
        Z,
        H,
        N_CTX,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        start_m = tl.program_id(0)
        off_hz = tl.program_id(1)
        # initialize offsets
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
        off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
        off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
        # Initialize pointers to Q, K, V
        q_ptrs = Q + off_q
        k_ptrs = K + off_k
        v_ptrs = V + off_v
        # initialize pointer to m and l
        t_ptrs = TMP + off_hz * N_CTX + offs_m
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # load q: it will stay in SRAM throughout
        q = tl.load(q_ptrs)
        # loop over k, v and update accumulator
        for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs + start_n * stride_kn)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, trans_b=True)
            qk *= sm_scale
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            tl.store(t_ptrs, acc_scale)
            acc_scale = tl.load(t_ptrs)    # BUG: have to store and immediately load
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs + start_n * stride_vk)
            p = p.to(tl.float16)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        # rematerialize offsets to save registers
        start_m = tl.program_id(0)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        # write back l and m
        l_ptrs = L + off_hz * N_CTX + offs_m
        m_ptrs = M + off_hz * N_CTX + offs_m
        tl.store(l_ptrs, l_i)
        tl.store(m_ptrs, m_i)
        # initialize pointers to output
        offs_n = tl.arange(0, BLOCK_DMODEL)
        off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
        out_ptrs = Out + off_o
        tl.store(out_ptrs, acc)
    @triton.jit
    def _bwd_preprocess(
        Out,
        DO,
        L,
        NewDO,
        Delta,
        BLOCK_M: tl.constexpr,
        D_HEAD: tl.constexpr,
    ):
        off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, D_HEAD)
        # load
        o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
        do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
        denom = tl.load(L + off_m).to(tl.float32)
        # compute
        do = do / denom[:, None]
        delta = tl.sum(o * do, axis=1)
        # write-back
        tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
        tl.store(Delta + off_m, delta)

    @triton.jit
    def _bwd_kernel(
        Q,
        K,
        V,
        sm_scale,
        Out,
        DO,
        DQ,
        DK,
        DV,
        L,
        M,
        D,
        stride_qz,
        stride_qh,
        stride_qm,
        stride_qk,
        stride_kz,
        stride_kh,
        stride_kn,
        stride_kk,
        stride_vz,
        stride_vh,
        stride_vk,
        stride_vn,
        Z,
        H,
        N_CTX,
        num_block,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        off_hz = tl.program_id(0)
        off_z = off_hz // H
        off_h = off_hz % H
        # offset pointers for batch/head
        Q += off_z * stride_qz + off_h * stride_qh
        K += off_z * stride_qz + off_h * stride_qh
        V += off_z * stride_qz + off_h * stride_qh
        DO += off_z * stride_qz + off_h * stride_qh
        DQ += off_z * stride_qz + off_h * stride_qh
        DK += off_z * stride_qz + off_h * stride_qh
        DV += off_z * stride_qz + off_h * stride_qh
        for start_n in range(0, num_block):
            lo = start_n * BLOCK_M
            # initialize row/col offsets
            offs_qm = lo + tl.arange(0, BLOCK_M)
            offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_m = tl.arange(0, BLOCK_N)
            offs_k = tl.arange(0, BLOCK_DMODEL)
            # initialize pointers to value-like data
            q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
            k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
            v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
            do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
            dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
            # pointer to row-wise quantities in value-like data
            D_ptrs = D + off_hz * N_CTX
            m_ptrs = M + off_hz * N_CTX
            # initialize dv and dk
            dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
            dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
            # k and v stay in SRAM throughout
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
            # loop over rows
            for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
                offs_m_curr = start_m + offs_m
                # load q, k, v, do on-chip
                q = tl.load(q_ptrs)
                # recompute p = softmax(qk, dim=-1).T
                # NOTE: `do` is pre-divided by `l`; no normalization here
                qk = tl.dot(q, k, trans_b=True)
                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
                m = tl.load(m_ptrs + offs_m_curr)
                p = tl.exp(qk * sm_scale - m[:, None])
                # compute dv
                do = tl.load(do_ptrs)
                dv += tl.dot(p.to(tl.float16), do, trans_a=True)
                # compute dp = dot(v, do)
                Di = tl.load(D_ptrs + offs_m_curr)
                dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
                dp += tl.dot(do, v, trans_b=True)
                # compute ds = p * (dp - delta[:, None])
                ds = p * dp * sm_scale
                # compute dk = dot(ds.T, q)
                dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
                # compute dq
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds.to(tl.float16), k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
                # increment pointers
                dq_ptrs += BLOCK_M * stride_qm
                q_ptrs += BLOCK_M * stride_qm
                do_ptrs += BLOCK_M * stride_qm
            # write-back
            dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
            dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
    class _TritonFlashAttention(torch.autograd.Function):

        @staticmethod
        def forward(ctx, q, k, v, sm_scale):
            BLOCK = 128
            # shape constraints
            Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
            assert Lq == Lk and Lk == Lv
            assert Lk in {16, 32, 64, 128}
            o = torch.empty_like(q)
            grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
            tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
            L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
            m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
            num_warps = 4 if Lk <= 64 else 8

            _fwd_kernel[grid](
                q,
                k,
                v,
                sm_scale,
                tmp,
                L,
                m,
                o,
                q.stride(0),
                q.stride(1),
                q.stride(2),
                q.stride(3),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                k.stride(3),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                v.stride(3),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                o.stride(3),
                q.shape[0],
                q.shape[1],
                q.shape[2],
                BLOCK_M=BLOCK,
                BLOCK_N=BLOCK,
                BLOCK_DMODEL=Lk,
                num_warps=num_warps,
                num_stages=1,
            )
            ctx.save_for_backward(q, k, v, o, L, m)
            ctx.BLOCK = BLOCK
            ctx.grid = grid
            ctx.sm_scale = sm_scale
            ctx.BLOCK_DMODEL = Lk
            return o

        @staticmethod
        def backward(ctx, do):
            q, k, v, o, l, m = ctx.saved_tensors
            do = do.contiguous()
            dq = torch.zeros_like(q, dtype=torch.float32)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            do_scaled = torch.empty_like(do)
            delta = torch.empty_like(l)
            _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
                o,
                do,
                l,
                do_scaled,
                delta,
                BLOCK_M=ctx.BLOCK,
                D_HEAD=ctx.BLOCK_DMODEL,
            )
            # NOTE: kernel currently buggy for other values of `num_warps`
            num_warps = 8
            _bwd_kernel[(ctx.grid[1],)](
                q,
                k,
                v,
                ctx.sm_scale,
                o,
                do_scaled,
                dq,
                dk,
                dv,
                l,
                m,
                delta,
                q.stride(0),
                q.stride(1),
                q.stride(2),
                q.stride(3),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                k.stride(3),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                v.stride(3),
                q.shape[0],
                q.shape[1],
                q.shape[2],
                ctx.grid[0],
                BLOCK_M=ctx.BLOCK,
                BLOCK_N=ctx.BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL,
                num_warps=num_warps,
                num_stages=1,
            )
            return dq, dk, dv, None
    def triton_flash_attention(q, k, v, sm_scale):
        """
        Arguments:
            q: (batch, nheads, seq, headdim)
            k: (batch, nheads, seq, headdim)
            v: (batch, nheads, seq, headdim)
            sm_scale: float. The scaling of QK^T before applying softmax.
        Return:
            out: (batch, nheads, seq, headdim)
        """
        if TRITON_AVALIABLE:
            return _TritonFlashAttention.apply(q, k, v, sm_scale)
        else:
            raise RuntimeError("Triton kernel requires CUDA 11.4+!")

if HAS_FLASH_ATTN:

    def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
        """
        Arguments:
            q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
            k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
            v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
            batch_size: int.
            seq_len: int.
            dropout_p: float. Dropout probability.
            sm_scale: float. The scaling of QK^T before applying softmax.
                Default to 1 / sqrt(headdim).
            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        Return:
            out: (total, nheads, headdim).
        """
        lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
        cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)

        return flash_attn_unpadded_func(q,
                                        k,
                                        v,
                                        cu_seqlens_q=cu_seqlens,
                                        cu_seqlens_k=cu_seqlens,
                                        max_seqlen_q=seq_len,
                                        max_seqlen_k=seq_len,
                                        dropout_p=dropout_p,
                                        softmax_scale=sm_scale,
                                        causal=causal)
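For reference, a minimal usage sketch of the two wrappers defined above. It assumes a CUDA device and the optional dependencies installed; tensor sizes are illustrative only.

import math

import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, TRITON_AVALIABLE

batch_size, nheads, seq_len, headdim = 4, 8, 128, 64
sm_scale = 1.0 / math.sqrt(headdim)

if TRITON_AVALIABLE:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

    # Triton path: (batch, nheads, seq, headdim) layout, fp16 on GPU.
    q = torch.randn(batch_size, nheads, seq_len, headdim, dtype=torch.float16, device="cuda")
    out = triton_flash_attention(q, q, q, sm_scale)    # (batch, nheads, seq, headdim)

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import flash_attention

    # flash-attn path: packed (total_tokens, nheads, headdim) layout,
    # where total_tokens = batch_size * seq_len for equal-length sequences.
    q = torch.randn(batch_size * seq_len, nheads, headdim, dtype=torch.float16, device="cuda")
    out = flash_attention(q, q, q, sm_scale, batch_size, seq_len)    # (total_tokens, nheads, headdim)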

View File

@@ -1,7 +1,14 @@
-import torch
import pytest
+import torch
from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
+
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+
+if HAS_FLASH_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import flash_attention
+
+if HAS_TRITON:
+    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -14,7 +21,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    ref_out = torch.matmul(p, v)
    return ref_out


+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
@@ -23,7 +31,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
@@ -51,6 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        raise TypeError("Error type not match!")


+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
@@ -59,21 +68,22 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # flash implementation
    q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
    tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
    dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
    tri_out.backward(dout, retain_graph=True)
    tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
-    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), (tri_out, tri_dq, tri_dk, tri_dv))
+    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                          (tri_out, tri_dq, tri_dk, tri_dv))

    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-3)
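The point of the new `skipif` guards: when an optional dependency is missing, the tests are now reported as skipped instead of the whole file failing at import time. A small self-contained illustration of the same pattern (the flag and test below are made up for demonstration):

import pytest

HAS_SOME_OPTIONAL_DEP = False    # stand-in availability flag, for illustration only


@pytest.mark.skipif(HAS_SOME_OPTIONAL_DEP == False, reason="optional dependency is not available")
def test_needs_optional_dep():
    # Reported by pytest as "skipped" (not "failed") when the flag is False.
    assert True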