You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/extensions/flash_attention/flash_attention_xformers_cu...

95 lines
3.5 KiB

from ..base_extension import _Extension
class FlashAttentionXformersCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)
def is_hardware_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def build_jit(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def load(self):
try:
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
except ImportError:
raise ModuleNotFoundError(
(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
)
from typing import Optional
import torch
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
def mem_eff_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: "SeqLenInfo",
seq_len_info_kv: "SeqLenInfo",
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
attn_bias = None
if padded: # bert style
if not causal:
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
elif causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position embedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert causal, "attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
if padded:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
# shape: (b*s, n, d)
if padded:
out = out.squeeze(0)
return out
return mem_eff_attention