[kernel] more flexible flashatt interface (#1804)

pull/1809/head^2
oahzxl 2022-11-07 17:02:09 +08:00 committed by GitHub
parent 20e255d4e8
commit 9639ea88fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 49 deletions

View File

@ -11,7 +11,7 @@ import subprocess
import torch
def triton_check():
def triton_cuda_check():
cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
cuda_version = cuda_version.split('release ')[1]
@ -27,7 +27,7 @@ def triton_check():
try:
import triton
import triton.language as tl
if triton_check():
if triton_cuda_check():
HAS_TRITON = True
else:
print("triton requires cuda >= 11.4")
@ -36,7 +36,11 @@ except ImportError:
print('please install triton from https://github.com/openai/triton')
HAS_TRITON = False
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_func,
flash_attn_unpadded_kvpacked_func,
flash_attn_unpadded_qkvpacked_func,
)
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
@ -405,12 +409,63 @@ if HAS_TRITON:
if HAS_FLASH_ATTN:
def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
"""
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
qkv: (batch * seqlen, 3, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
dropout_p: float.
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (total, nheads, headdim).
"""
max_s = seq_len
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
device=qkv.device)
out = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, dropout_p,
softmax_scale=sm_scale, causal=causal
)
return out
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
"""
Arguments:
q: (batch * q_seqlen, nheads, headdim)
kv: (batch * kv_seqlen, 2, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
dropout_p: float.
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (total, nheads, headdim).
"""
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
out = flash_attn_unpadded_kvpacked_func(q,
kv,
cu_seqlens_q,
cu_seqlens_k,
q_seqlen,
kv_seqlen,
dropout_p,
sm_scale,
causal)
return out
def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
"""
Arguments:
q: (batch * q_seqlen, nheads, headdim)
k: (batch * kv_seqlen, nheads, headdim)
v: (batch * kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
@ -420,16 +475,15 @@ if HAS_FLASH_ATTN:
Return:
out: (total, nheads, headdim).
"""
lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
return flash_attn_unpadded_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
dropout_p=dropout_p,
softmax_scale=sm_scale,
causal=causal)
cu_seqlens_q,
cu_seqlens_kv,
q_seqlen,
kv_seqlen,
dropout_p,
sm_scale,
causal)

View File

@ -5,7 +5,8 @@ from einops import rearrange
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
if HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.flash_attention import flash_attention
from colossalai.kernel.cuda_native.flash_attention import (
flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)
if HAS_TRITON:
from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
@ -22,8 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
@ -39,28 +40,20 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
if HAS_TRITON:
tri_out = triton_flash_attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
else:
try:
tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
except RuntimeError:
pass
else:
raise TypeError("Error type not match!")
tri_out = triton_flash_attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
@ -78,15 +71,40 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
# flash implementation
q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
tri_out.backward(dout, retain_graph=True)
tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk, tri_dv))
for i in range(3):
if i == 0:
tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
elif i == 1:
kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
else:
qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
tri_out.backward(dout, retain_graph=True)
if i == 0:
tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk, tri_dv))
elif i == 1:
tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
else:
tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
if __name__ == '__main__':
test_flash_attention(3, 4, 2, 16)