mirror of https://github.com/hpcaitech/ColossalAI

updated attention kernel (#2133)

parent 484fe62252
commit 077a66dd81
@@ -48,6 +48,13 @@ except ImportError:
     HAS_FLASH_ATTN = False
     print('please install flash_attn from https://github.com/HazyResearch/flash-attention')

+try:
+    from xformers.ops.fmha import memory_efficient_attention
+    HAS_MEM_EFF_ATTN = True
+except ImportError:
+    HAS_MEM_EFF_ATTN = False
+    print('please install xformers from https://github.com/facebookresearch/xformers')
+
 if HAS_TRITON:

     @triton.jit
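This hunk mirrors the optional-import pattern already used for flash_attn: the xformers import is attempted once at module load and the result is recorded in the module-level flag HAS_MEM_EFF_ATTN. As a rough sketch of how these flags can be used downstream (not part of this commit; the helper function below is hypothetical):

# Hypothetical helper, for illustration only: pick an attention backend
# based on the availability flags exported by flash_attention.py.
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN


def pick_attention_backend() -> str:
    if HAS_FLASH_ATTN:
        return 'flash-attention'   # fused kernels from HazyResearch/flash-attention
    if HAS_MEM_EFF_ATTN:
        return 'xformers'          # memory-efficient attention from facebookresearch/xformers
    return 'torch'                 # plain PyTorch softmax(QK^T)V fallback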
@@ -497,3 +504,22 @@ if HAS_FLASH_ATTN:
                                              device=k.device)
         return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
                                         causal)
+
+
+if HAS_MEM_EFF_ATTN:
+
+    from einops import rearrange
+    from xformers.ops.fmha import LowerTriangularMask
+
+    class MemoryEfficientAttention(torch.nn.Module):
+
+        def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
+            super().__init__()
+            attention_head_size = hidden_size // num_attention_heads
+            self.scale = 1 / attention_head_size**0.5
+            self.dropout = attention_dropout
+
+        def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
+            context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
+            context = rearrange(context, 'b s h d -> b s (h d)')
+            return context
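The new MemoryEfficientAttention module is a thin wrapper around xformers' memory_efficient_attention: the softmax scale is derived from the per-head size (1 / sqrt(hidden_size // num_attention_heads)), the supplied attention_mask is forwarded as the attention bias, and the last two dimensions of the result are merged back into a single hidden dimension. A minimal usage sketch, assuming xformers is installed and a CUDA device is available; the shapes and hyperparameters mirror the new test further below:

import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN

if HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention

    Z, H, N_CTX, D_HEAD = 6, 8, 4, 16    # illustrative sizes, same as the new test case
    attn = MemoryEfficientAttention(hidden_size=N_CTX * D_HEAD, num_attention_heads=N_CTX, attention_dropout=0.1)

    # q/k/v use the (Z, H, N_CTX, D_HEAD) layout of the test; fp16 on CUDA is what the test exercises.
    q = torch.randn(Z, H, N_CTX, D_HEAD, dtype=torch.float16, device='cuda', requires_grad=True)
    k = torch.randn(Z, H, N_CTX, D_HEAD, dtype=torch.float16, device='cuda', requires_grad=True)
    v = torch.randn(Z, H, N_CTX, D_HEAD, dtype=torch.float16, device='cuda', requires_grad=True)

    # LowerTriangularMask requests causal attention inside xformers.
    out = attn(q, k, v, attention_mask=LowerTriangularMask())    # shape: (Z, H, N_CTX * D_HEAD)
    out.backward(torch.rand_like(out))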
@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange

-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON

 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import (
@@ -15,6 +15,9 @@ if HAS_FLASH_ATTN:
 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

+if HAS_MEM_EFF_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
+

 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -124,5 +127,20 @@ def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     out.backward(dout)


+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
+def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
+
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+
+    out = attn(q, k, v, attention_mask=LowerTriangularMask())
+
+    dout = torch.rand_like(out)
+    out.backward(dout)
+
+
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)
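The new test is skipped automatically when xformers is not installed (HAS_MEM_EFF_ATTN is False). One way to run just this test locally, sketched via pytest's Python entry point; the selection expression is an assumption and should be adjusted to wherever the test file lives in the repository:

# Run only the xformers-backed test; requires pytest, CUDA, and xformers.
import pytest

if __name__ == '__main__':
    raise SystemExit(pytest.main(['-q', '-k', 'test_memory_efficient_attention']))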