from ..base_extension import _Extension


class FlashAttentionXformersCudaExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)

    def is_hardware_available(self) -> bool:
        # cuda extension can only be built if cuda is available
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        except Exception:
            cuda_available = False
        return cuda_available
    def assert_hardware_compatible(self) -> bool:
        pass
    def build_aot(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
        )

    def build_jit(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
        )
    def load(self):
        try:
            from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
            from xformers.ops.fmha.attn_bias import (
                BlockDiagonalCausalMask,
                BlockDiagonalMask,
                LowerTriangularMask,
                LowerTriangularMaskWithTensorBias,
            )
        except ImportError:
            raise ModuleNotFoundError(
                "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
            )
        from typing import Optional

        import torch
        # ALiBi-style tensor biases are only usable if every CUTLASS op variant supports them.
        allow_alibi = True
        for op in MemoryEfficientAttentionCutlassOp:
            allow_alibi = allow_alibi and (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
        def mem_eff_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q: "SeqLenInfo",
            seq_len_info_kv: "SeqLenInfo",
            origin_attn_mask: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            causal: bool = False,
            padded: bool = False,
        ):
            attn_bias = None
            if padded:  # bert style
                if not causal:
                    attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
                else:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            elif causal:  # gpt style
                attn_bias = LowerTriangularMask()

            if bias is not None:  # alibi / relative position embedding
                assert allow_alibi, "flash attention with bias is not supported in this system."
                assert causal, "attention with bias is only supported for causal attention so far."
                attn_bias = attn_bias.add_bias(bias)

            if padded:
                q = q.unsqueeze(0)
                k = k.unsqueeze(0)
                v = v.unsqueeze(0)

            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

            # shape: (b*s, n, d)
            if padded:
                out = out.squeeze(0)

            return out
        return mem_eff_attention
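

# Usage sketch (illustrative only, not part of the original module): a minimal
# example of how the loaded kernel might be exercised on the non-padded causal
# (GPT-style) path. The seq-len info arguments are only consulted when
# padded=True, so None is passed here purely as a placeholder; tensor shapes
# follow xformers' memory_efficient_attention convention of
# (batch, seq_len, num_heads, head_dim).
if __name__ == "__main__":
    import torch

    ext = FlashAttentionXformersCudaExtension()
    if ext.is_hardware_available():
        attn_fn = ext.load()
        q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
        k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
        v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
        out = attn_fn(q, k, v, None, None, causal=True)
        print(out.shape)  # expected: torch.Size([2, 128, 8, 64])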