mirror of https://github.com/hpcaitech/ColossalAI
[kernel] more flexible flashatt interface (#1804)
parent 20e255d4e8
commit 9639ea88fc
@@ -11,7 +11,7 @@ import subprocess
 import torch


-def triton_check():
+def triton_cuda_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
     cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
     cuda_version = cuda_version.split('release ')[1]
@@ -27,7 +27,7 @@ def triton_check():
 try:
     import triton
     import triton.language as tl
-    if triton_check():
+    if triton_cuda_check():
         HAS_TRITON = True
     else:
         print("triton requires cuda >= 11.4")
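For context, only the first lines of triton_cuda_check() fall inside these hunks. A minimal sketch of the full check, assuming it compares the parsed release number against the "cuda >= 11.4" requirement printed in the import guard above (everything past the parsing lines is an assumption, not the committed body):

import os
import subprocess

def triton_cuda_check():
    # Parse the CUDA release number out of `nvcc --version` and require >= 11.4.
    # Sketch only; the committed function body beyond the parsing lines is not
    # shown in this hunk.
    cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
    cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
    cuda_version = cuda_version.split('release ')[1]
    cuda_version = cuda_version.split(',')[0]          # e.g. "11.4"
    major, minor = (int(x) for x in cuda_version.split('.')[:2])
    return (major, minor) >= (11, 4)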
@@ -36,7 +36,11 @@ except ImportError:
     print('please install triton from https://github.com/openai/triton')
     HAS_TRITON = False
 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_func,
+        flash_attn_unpadded_kvpacked_func,
+        flash_attn_unpadded_qkvpacked_func,
+    )
     HAS_FLASH_ATTN = True
 except ImportError:
     HAS_FLASH_ATTN = False
@@ -405,12 +409,63 @@ if HAS_TRITON:

 if HAS_FLASH_ATTN:

-    def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
+    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
         """
         Arguments:
-            q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-            k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
-            v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
+            qkv: (batch * seqlen, 3, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
             sm_scale: float. The scaling of QK^T before applying softmax.
                 Default to 1 / sqrt(headdim).
             dropout_p: float.
+            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        Return:
+            out: (total, nheads, headdim).
+        """
+        max_s = seq_len
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
+                                  device=qkv.device)
+        out = flash_attn_unpadded_qkvpacked_func(
+            qkv, cu_seqlens, max_s, dropout_p,
+            softmax_scale=sm_scale, causal=causal
+        )
+        return out
+
+
+    def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
+        """
+        Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            kv: (batch * kv_seqlen, 2, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            sm_scale: float. The scaling of QK^T before applying softmax.
+                Default to 1 / sqrt(headdim).
+            dropout_p: float.
+            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        Return:
+            out: (total, nheads, headdim).
+        """
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
+        out = flash_attn_unpadded_kvpacked_func(q,
+                                                kv,
+                                                cu_seqlens_q,
+                                                cu_seqlens_k,
+                                                q_seqlen,
+                                                kv_seqlen,
+                                                dropout_p,
+                                                sm_scale,
+                                                causal)
+        return out
+
+
+    def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
+        """
+        Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            k: (batch * kv_seqlen, nheads, headdim)
+            v: (batch * kv_seqlen, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            dropout_p: float. Dropout probability.
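The cu_seqlens tensors built with torch.arange above are the cumulative sequence-start offsets that the unpadded flash-attn kernels expect for a batch of equal-length sequences. A small worked example with illustrative sizes:

import torch

batch_size, seq_len = 3, 4    # illustrative values
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32)
print(cu_seqlens)             # tensor([ 0,  4,  8, 12], dtype=torch.int32)
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the flattened
# (batch * seq_len, nheads, headdim) tensor.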
@@ -420,16 +475,15 @@ if HAS_FLASH_ATTN:
         Return:
             out: (total, nheads, headdim).
         """
-        lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
-        cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
-        cu_seqlens[1:] = lengths.cumsum(0)
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
         return flash_attn_unpadded_func(q,
                                         k,
                                         v,
-                                        cu_seqlens_q=cu_seqlens,
-                                        cu_seqlens_k=cu_seqlens,
-                                        max_seqlen_q=seq_len,
-                                        max_seqlen_k=seq_len,
-                                        dropout_p=dropout_p,
-                                        softmax_scale=sm_scale,
-                                        causal=causal)
+                                        cu_seqlens_q,
+                                        cu_seqlens_kv,
+                                        q_seqlen,
+                                        kv_seqlen,
+                                        dropout_p,
+                                        sm_scale,
+                                        causal)
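A minimal usage sketch of the three new wrappers, assuming (batch, heads, seq, head_dim) inputs that are flattened to the (batch * seq_len, nheads, headdim) layout described in the docstrings; the sizes and the rearrange step mirror the test below and are illustrative, not part of the commit:

import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import (
    flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)

Z, H, N, D = 3, 4, 2, 16                 # batch, heads, seq_len, head_dim (illustrative)
sm_scale = D ** -0.5
q, k, v = [torch.randn((Z, H, N, D), dtype=torch.float16, device="cuda") for _ in range(3)]

# Flatten to the (batch * seq_len, nheads, headdim) layout the docstrings describe.
q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])

# Separate q, k, v tensors.
out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N, N, causal=True)

# kv packed along dim=1: (batch * seq_len, 2, nheads, headdim).
kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
out = flash_attention_q_kv(q, kv, sm_scale, Z, N, N, causal=True)

# qkv packed along dim=1: (batch * seq_len, 3, nheads, headdim).
qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
out = flash_attention_qkv(qkv, sm_scale, Z, N, causal=True)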
@@ -5,7 +5,8 @@ from einops import rearrange
 from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

 if HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native.flash_attention import flash_attention
+    from colossalai.kernel.cuda_native.flash_attention import (
+        flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)

 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
@@ -22,8 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out


-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
+@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
@@ -39,28 +40,20 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None

     # triton implementation
-    if HAS_TRITON:
-        tri_out = triton_flash_attention(q, k, v, sm_scale)
-        tri_out.backward(dout)
-        tri_dv, v.grad = v.grad.clone(), None
-        tri_dk, k.grad = k.grad.clone(), None
-        tri_dq, q.grad = q.grad.clone(), None
-        # compare
-        assert torch.allclose(ref_out, tri_out, atol=1e-3)
-        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
-        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
-        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
-    else:
-        try:
-            tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
-        except RuntimeError:
-            pass
-        else:
-            raise TypeError("Error type not match!")
+    tri_out = triton_flash_attention(q, k, v, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-3)
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


 @pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
@@ -78,15 +71,40 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):

     # flash implementation
     q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
-    tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
     dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
-    tri_out.backward(dout, retain_graph=True)
-    tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
-    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                          (tri_out, tri_dq, tri_dk, tri_dv))
+    for i in range(3):
+        if i == 0:
+            tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
+        elif i == 1:
+            kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
+            tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
+        else:
+            qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
+            tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)

-    # compare
-    assert torch.allclose(ref_out, tri_out, atol=1e-3)
-    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
-    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
-    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
+        tri_out.backward(dout, retain_graph=True)
+
+        if i == 0:
+            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq, tri_dk, tri_dv))
+        elif i == 1:
+            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
+            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
+        else:
+            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
+            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
+
+        # compare
+        assert torch.allclose(ref_out, tri_out, atol=1e-3)
+        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
+        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
+        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
+
+
+if __name__ == '__main__':
+    test_flash_attention(3, 4, 2, 16)
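The gradient handling above splits the packed gradients back apart with torch.chunk and squeeze. A short round-trip sketch of that packing convention, with illustrative sizes:

import torch

N, H, D = 8, 4, 16                                  # illustrative sizes
k = torch.randn(N, H, D)
v = torch.randn(N, H, D)

# Pack along a new dim=1, as the test does: (N, 2, H, D).
kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)

# Unpack the same way the kv-packed gradient is split back into k and v parts.
k2, v2 = torch.chunk(kv, 2, dim=1)
assert torch.equal(k, k2.squeeze(1))
assert torch.equal(v, v2.squeeze(1))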