[kernel] Add triton kernel for context attention (FAv2) without padding (#5192)

* add context attn unpadded triton kernel

* test compatibility

* kv cache copy (testing)

* fix k/v cache copy

* fix kv cache copy and test

* fix boundary of block ptrs

* add support for GQA/MQA and testing

* fix import statement

---------

Co-authored-by: Round Heng <yuanhengzhao@Rounds-MacBook-Pro.local>
Branch: pull/5258/head
Author: Yuanheng Zhao, committed by FrankLeeeee
parent 4df8876fca
commit 07b5283b6a
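
A minimal usage sketch of the new kernel (tensor shapes follow the comments on `context_attention_unpadded` in the diff below; the sizes and tensors here are illustrative only and are not part of this commit):

import torch
from colossalai.kernel.triton import context_attention_unpadded

# toy sizes, illustrative only
num_seqs, num_heads, num_kv_heads, head_size, block_size = 2, 8, 4, 64, 16
context_lengths = torch.tensor([13, 30], dtype=torch.int32, device="cuda")
num_tokens = int(context_lengths.sum())
max_blocks = (int(context_lengths.max()) + block_size - 1) // block_size

q = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn(num_tokens, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)
k_cache = torch.zeros(num_seqs * max_blocks, num_kv_heads, head_size, block_size, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
block_tables = torch.arange(num_seqs * max_blocks, dtype=torch.int32, device="cuda").view(num_seqs, max_blocks)

out = context_attention_unpadded(q, k, v, k_cache, v_cache, context_lengths, block_tables, block_size)
# out: [num_tokens, num_heads, head_size]; k_cache/v_cache now hold the prefill K/V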

@@ -8,11 +8,13 @@ except ImportError:
# There may exist import error even if we have triton installed.
if HAS_TRITON:
    from .context_attn_unpad import context_attention_unpadded
    from .fused_layernorm import layer_norm
    from .gptq_triton import gptq_fused_linear_triton
    from .softmax import softmax
    __all__ = [
        "context_attention_unpadded",
        "softmax",
        "layer_norm",
        "gptq_fused_linear_triton",

@@ -0,0 +1,262 @@
# Applies FlashAttention V2 as described in:
# "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"
# by Tri Dao, 2023
# https://github.com/Dao-AILab/flash-attention
#
# Inspired by and modified from the Triton tutorial - Fused Attention
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
import torch
import triton
import triton.language as tl
# Triton 2.1.0
@triton.jit
def _fwd_context_paged_attention_kernel(
    Q,
    K,
    V,
    O,
    KCache,
    VCache,
    BLOCK_TABLES,  # [num_seqs, max_blocks_per_sequence]
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_ot,
    stride_oh,
    stride_od,
    stride_cacheb,
    stride_cacheh,
    stride_cached,
    stride_cachebs,
    stride_bts,
    stride_btb,
    context_lengths,
    sm_scale,
    KV_GROUPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_seq_idx = tl.program_id(0)
    cur_head_idx = tl.program_id(1)
    block_start_m = tl.program_id(2)  # current query-row block (Br); grid dim 2 is cdiv(max_input_len, BLOCK_M)
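    # Map the query head to the KV head it shares under GQA/MQA (KV_GROUPS = num_query_heads // num_kv_heads)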
    cur_kv_head_idx = cur_head_idx // KV_GROUPS
    # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same
    tl.static_assert(BLOCK_M == BLOCK_N)
    tl.static_assert(BLOCK_N == BLOCK_SIZE)
    # get the current sequence length from the provided context lengths tensor
    cur_seq_len = tl.load(context_lengths + cur_seq_idx)
    # NOTE When using fused QKV with unpadded (no-padding) context attention,
    # we assume the input Q/K/V is contiguous, so `prev_seq_len_sum`
    # can be treated as the start token index of the current sequence.
    # FIXME Might want to explore a better way to compute the sum of previous sequence lengths.
    # `tl.sum(tensor[:end])` is invalid, as tensor slicing is not supported in Triton.
    prev_seq_len_sum = 0
    for i in range(0, cur_seq_idx):
        prev_seq_len_sum += tl.load(context_lengths + i)
    q_offset = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
    kv_offset = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(cur_seq_len, BLOCK_DMODEL),
        strides=(stride_qt, stride_qd),
        offsets=(block_start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + kv_offset,
        shape=(BLOCK_DMODEL, cur_seq_len),
        strides=(stride_kd, stride_kt),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset,
        shape=(cur_seq_len, BLOCK_DMODEL),
        strides=(stride_vt, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + q_offset,
        shape=(cur_seq_len, BLOCK_DMODEL),
        strides=(stride_ot, stride_od),
        offsets=(block_start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Block table for the current sequence
    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
    # Block indices within the block table (i.e. 0, 1, 2, ..., max_blocks_per_seq - 1).
    # `block_start_m` can be used directly as the logical block index in the current block table,
    # since BLOCK_M has the same size as the physical cache block.
    cur_block_table_idx = block_start_m
    cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
    kvcache_offset = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
    offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offsets_n = tl.arange(0, BLOCK_N)
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
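    # Query blocks that start beyond the current sequence length have nothing to compute or copy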
    if block_start_m * BLOCK_M >= cur_seq_len:
        return
    Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0))
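    # Causal attention: iterate over K/V blocks up to and including the diagonal (current query) block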
    for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N):
        block_start_n = tl.multiple_of(block_start_n, BLOCK_N)
        k = tl.load(K_block_ptr, boundary_check=(0, 1))
        S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        S_ij += tl.dot(Q_i, k)
        S_ij *= sm_scale
        S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf"))
        m_ij = tl.max(S_ij, 1)  # rowmax(Sij)
        m_ij = tl.maximum(m_i, m_ij)  # m_ij
        S_ij -= m_ij[:, None]
        p_ij_hat = tl.exp(S_ij)
        scale = tl.exp(m_i - m_ij)
        l_ij = scale * l_i + tl.sum(p_ij_hat, 1)
        acc = acc * scale[:, None]
        v = tl.load(V_block_ptr, boundary_check=(1, 0))
        p_ij_hat = p_ij_hat.to(v.type.element_ty)
        acc += tl.dot(p_ij_hat, v)
        l_i = l_ij
        m_i = m_ij
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    acc = acc / l_i[:, None]
    tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0))
    if cur_head_idx % KV_GROUPS == 0:
        # Copy k to corresponding cache block
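        # K for this block is gathered as [head_size, block_size] (token index last) to match the cache layout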
        kd_offsets = tl.arange(0, BLOCK_DMODEL)
        kt_offsets = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        k_offsets = K + kv_offset + kd_offsets[:, None] * stride_kd + kt_offsets[None, :] * stride_kt
        k = tl.load(k_offsets, mask=kt_offsets[None, :] < cur_seq_len, other=0.0)
        kcached_offsets = tl.arange(0, BLOCK_DMODEL)
        kcachebs_offsets = tl.arange(0, BLOCK_SIZE)
        kcache_offsets = (
            KCache
            + kvcache_offset
            + kcached_offsets[:, None] * stride_cached
            + kcachebs_offsets[None, :] * stride_cachebs
        )
        tl.store(kcache_offsets, k, mask=kcachebs_offsets[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
        # Copy v to corresponding cache block
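        # V is gathered as [block_size, head_size] and written transposed into the cache's [head_size, block_size] layout via strides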
        vd_offsets = kd_offsets
        vt_offsets = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
        v_offsets = V + kv_offset + vt_offsets[:, None] * stride_vt + vd_offsets[None, :] * stride_vd
        v = tl.load(v_offsets, mask=vt_offsets[:, None] < cur_seq_len, other=0.0)
        vcached_offsets = kcached_offsets
        vcachebs_offsets = kcachebs_offsets
        vcache_offsets = (
            VCache
            + kvcache_offset
            + vcachebs_offsets[:, None] * stride_cachebs
            + vcached_offsets[None, :] * stride_cached
        )
        tl.store(vcache_offsets, v, mask=vcachebs_offsets[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
    return
def context_attention_unpadded(
    q: torch.Tensor,  # [num_tokens, num_heads, head_size]
    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
    context_lengths: torch.Tensor,  # [num_seqs]
    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence]
    block_size: int,
):
    # k/v in the context (prefill) stage are supposed to be copied into k_cache and v_cache.
    # This step can be optimized in the future.
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk == Lv
    assert Lk in {32, 64, 128, 256}
    assert q.shape[0] == k.shape[0] == v.shape[0]
    assert k_cache.shape == v_cache.shape
    assert context_lengths.shape[0] == block_tables.shape[0]
    num_tokens, num_heads, _ = q.shape
    num_kv_heads = k.shape[-2]
    assert num_kv_heads > 0 and num_heads % num_kv_heads == 0
    num_kv_group = num_heads // num_kv_heads
    num_seqs, max_blocks_per_seq = block_tables.shape
    max_seq_len = context_lengths.max().item()
    sm_scale = 1.0 / (Lq**0.5)
    output = torch.zeros_like(q)
    # NOTE For now, BLOCK_M and BLOCK_N must be equal to the size of the physical cache block (i.e. `block_size`)
    assert block_size in {16, 32, 64, 128}
    BLOCK_M = BLOCK_N = block_size
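    # Launch one program instance per (sequence, attention head, BLOCK_M-sized block of query tokens)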
    grid = (num_seqs, num_heads, triton.cdiv(max_seq_len, BLOCK_M))
    _fwd_context_paged_attention_kernel[grid](
        q,
        k,
        v,
        output,
        k_cache,
        v_cache,
        block_tables,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        context_lengths,
        sm_scale,
        num_kv_group,
        block_size,
        BLOCK_DMODEL=Lk,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
    return output
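
For reference (not part of this diff): the inner loop of the kernel above is the FlashAttention-2 streaming-softmax recurrence. A minimal PyTorch sketch of the same update, with causal masking omitted for brevity:

import torch

def streaming_softmax_attention(q_block, k_tiles, v_tiles, sm_scale):
    # q_block: [M, d]; k_tiles/v_tiles: iterables of [N, d] tiles of K and V
    M, d = q_block.shape
    m = torch.full((M,), float("-inf"))  # running row-wise max
    l = torch.zeros(M)                   # running softmax denominator
    acc = torch.zeros(M, d)              # running (unnormalized) output
    for k_t, v_t in zip(k_tiles, v_tiles):
        s = (q_block @ k_t.T) * sm_scale               # scores for this tile
        m_new = torch.maximum(m, s.max(dim=1).values)  # updated running max
        p = torch.exp(s - m_new[:, None])              # tile-local unnormalized probs
        rescale = torch.exp(m - m_new)                 # rescale previous stats to the new max
        l = rescale * l + p.sum(dim=1)
        acc = acc * rescale[:, None] + p @ v_t
        m = m_new
    return acc / l[:, None]              # single normalization at the end (FAv2 style)

# quick check against standard softmax attention
q = torch.randn(8, 16); k = torch.randn(32, 16); v = torch.randn(32, 16)
ref = torch.softmax((q @ k.T) / 16**0.5, dim=-1) @ v
out = streaming_softmax_attention(q, list(k.split(8)), list(v.split(8)), 1 / 16**0.5)
assert torch.allclose(out, ref, atol=1e-4)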

@@ -0,0 +1,158 @@
import pytest
import torch
import torch.nn.functional as F
from packaging import version
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
try:
    import triton  # noqa
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_attn_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_len: int, num_heads: int, head_size: int):
    # For a single sequence, q,k,v [seq_len, num_heads, head_size]
    assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_size
    q = q.view(seq_len, num_heads, head_size)
    k = k.view(seq_len, num_heads, head_size)
    v = v.view(seq_len, num_heads, head_size)
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    mask = torch.tril(torch.ones(1, seq_len, seq_len), diagonal=0).to(device=get_current_device())
    mask[mask == 0.0] = float("-inf")
    mask = mask.repeat(num_heads, 1, 1)
    qk = torch.matmul(q, k.transpose(1, 2))
    attn_scores = qk / (head_size**0.5)
    attn_weights = F.softmax(attn_scores.to(dtype=torch.float32) + mask, dim=-1).to(dtype=q.dtype)
    out = torch.matmul(attn_weights, v).transpose(0, 1).contiguous()
    out = out.reshape(-1, num_heads, head_size)
    return out
def torch_attn_unpad(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor):
    # Process sequences one by one and concatenate the outputs.
    # q,k,v [num_tokens(sum(context_lengths)), num_heads, head_size]
    assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor"
    _, num_heads, head_size = q.shape
    out_torch = []
    start_idx = 0
    for i in range(len(context_lengths)):
        end_idx = start_idx + context_lengths[i].item()
        torch_attn_ref_out = torch_attn_ref(
            q[start_idx:end_idx], k[start_idx:end_idx], v[start_idx:end_idx], end_idx - start_idx, num_heads, head_size
        )
        out_torch.append(torch_attn_ref_out)
        start_idx = end_idx
    return torch.cat(out_torch, dim=0)
# This method is adapted from src/transformers/models/llama/modeling_llama.py
# in transformers repository https://github.com/huggingface/transformers
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (num_tokens,
num_key_value_heads, head_dim) to (num_tokens, num_attention_heads, head_dim)
"""
num_tokens, num_key_value_heads, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :].expand(num_tokens, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(num_tokens, num_key_value_heads * n_rep, head_dim)
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_context_attention(
    bsz: int,
    block_size: int,
    max_num_blocks_per_seq: int,
    num_attn_heads: int,
    kv_group_num: int,
    same_context_len: bool,
):
    torch.manual_seed(123)
    dtype = torch.float16
    device = get_current_device()
    num_seqs = bsz
    num_kv_heads = num_attn_heads // kv_group_num
    assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
    head_size = 32
    max_seq_len = max_num_blocks_per_seq * block_size
    # It's necessary to clear cache here.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    if same_context_len:
        context_lengths = torch.tensor([max_seq_len for _ in range(num_seqs)], dtype=torch.int32, device=device)
    else:
        context_lengths = torch.randint(low=1, high=max_seq_len, size=(num_seqs,), dtype=torch.int32, device=device)
    num_tokens = torch.sum(context_lengths).item()
    qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, head_size)
    qkv = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    q, k, v = torch.split(qkv, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_size, block_size)
    k_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    k_cache_triton = torch.zeros_like(k_cache_torch)
    v_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache_triton = torch.zeros_like(v_cache_torch)
    # Mock allocation on block tables
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0
    for i, seq_len in enumerate(context_lengths.tolist()):
        right_bound = (seq_len + block_size - 1) // block_size  # number of blocks used by this sequence (exclusive bound)
        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
        # Manually fill k_cache_torch and v_cache_torch by copying from k and v
        for i in range(right_bound):
            if i == right_bound - 1:
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
            cur_block_size_occupied = k_block.shape[-1]
            assert cur_block_size_occupied <= block_size, "Invalid occupied size of block during mock allocation"
            k_cache_torch[block_id, :, :, :cur_block_size_occupied] = k_block
            v_cache_torch[block_id, :, :, :cur_block_size_occupied] = v_block
            num_tokens_processed += allocated_locs
            block_id += 1
    block_tables = block_tables.to(device=device)
    out_triton = context_attention_unpadded(
        q, k, v, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
    )
    # For GQA and MQA, repeat k, v for torch attention calculation
    # k/v won't change if provided `num_kv_group` is 1
    num_kv_group = num_attn_heads // num_kv_heads
    k = repeat_kv(k, num_kv_group)
    v = repeat_kv(v, num_kv_group)
    out_torch = torch_attn_unpad(q, k, v, context_lengths)
    assert out_torch.shape == out_triton.shape
    assert torch.allclose(out_torch, out_triton, atol=1e-2, rtol=1e-3)
    assert torch.allclose(k_cache_torch, k_cache_triton)
    assert torch.allclose(v_cache_torch, v_cache_triton)
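
Side note on the paged-cache layout exercised by this test: each prefill token's K/V lands at (block, kv_head, :, slot) in a cache of shape [num_blocks, num_kv_heads, head_size, block_size]. A minimal illustrative helper (hypothetical, not part of the test) that performs the same mapping as the kernel's cache copy and the mock allocation above:

import torch

def write_kv_to_paged_cache(k, v, k_cache, v_cache, block_table, seq_start, seq_len, block_size):
    # k/v: [num_tokens, num_kv_heads, head_size]; caches: [num_blocks, num_kv_heads, head_size, block_size]
    for t in range(seq_len):
        block_id = int(block_table[t // block_size])  # physical block holding this token
        slot = t % block_size                         # slot of the token within that block
        k_cache[block_id, :, :, slot] = k[seq_start + t]
        v_cache[block_id, :, :, slot] = v[seq_start + t]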