from ..base_extension import _Extension


class FlashAttentionDaoCudaExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)

    def is_available(self) -> bool:
        # flash attention is only available if flash-attn is installed and CUDA is available
        try:
            import torch

            from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func  # noqa
            from flash_attn.bert_padding import index_first_axis, pad_input  # noqa

            cuda_available = torch.cuda.is_available()
        except:
            cuda_available = False
        return cuda_available

    def assert_compatible(self) -> bool:
        pass

    def build_aot(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'."
        )

    def build_jit(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'."
        )

    def load(self):
        from typing import Optional

        import torch
        from einops import rearrange
        from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func
        from flash_attn.bert_padding import index_first_axis, pad_input

        def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
            return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)

        def flash_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            attention_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
            cu_seqlens_q: Optional[torch.Tensor] = None,
            cu_seqlens_kv: Optional[torch.Tensor] = None,
            max_seqlen_q: Optional[int] = None,
            max_seqlen_kv: Optional[int] = None,
            q_indices: Optional[torch.Tensor] = None,
            kv_indices: Optional[torch.Tensor] = None,
        ):
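            """Run the Dao-AILab flash-attn kernels on q, k, v given in [B, N, S, D] layout.

            When ``cu_seqlens_q`` is provided, the padded inputs are first unpadded
            via ``q_indices``/``kv_indices`` and the varlen kv-packed kernel is used,
            then the output is re-padded. Otherwise the dense ``flash_attn_func``
            kernel handles the causal / no-mask case. ``attention_mask`` is accepted
            for interface compatibility but is not read here. The result is returned
            in [B, N, S, D] layout.
            """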
            # [B, N, S, D] -> [B, S, N, D]
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            b, s_q = q.shape[:2]
            if cu_seqlens_q is not None:
                # padded / padded causal
                # unpad input: [B, S, N, D] -> [T, N, D]
                q = _unpad_input(q, q_indices)
                kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
                attn_output = flash_attn_varlen_kvpacked_func(
                    q,
                    kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    dropout_p=dropout_p,
                    softmax_scale=scale,
                    causal=is_causal,
                )
                # pad output: [T, N, D] -> [B, S, N, D]
                attn_output = pad_input(attn_output, q_indices, b, s_q)
            else:
                # causal / no attn mask
                attn_output = flash_attn_func(
                    q,
                    k,
                    v,
                    dropout_p=dropout_p,
                    softmax_scale=scale,
                    causal=is_causal,
                )
            # [B, S, N, D] -> [B, N, S, D]
            return attn_output.transpose(1, 2)

        return flash_attention
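

# Usage sketch (illustrative only; the shapes, dtype, and device below are assumed
# for demonstration and require a CUDA device with the flash-attn package installed):
#
#   ext = FlashAttentionDaoCudaExtension()
#   if ext.is_available():
#       attn_fn = ext.load()
#       # q/k/v in [B, N, S, D] layout, fp16 on GPU
#       q = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
#       k = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
#       v = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
#       out = attn_fn(q, k, v, is_causal=True)  # -> [B, N, S, D]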