mirror of https://github.com/hpcaitech/ColossalAI
71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
import warnings
|
|
|
|
HAS_MEM_EFF_ATTN = False
|
|
try:
|
|
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
|
|
from xformers.ops.fmha.attn_bias import (
|
|
BlockDiagonalCausalMask,
|
|
BlockDiagonalMask,
|
|
LowerTriangularMask,
|
|
LowerTriangularMaskWithTensorBias,
|
|
)
|
|
|
|
HAS_MEM_EFF_ATTN = True
|
|
except ImportError:
|
|
warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
|
|
HAS_MEM_EFF_ATTN = False
|
|
|
|
if HAS_MEM_EFF_ATTN:
|
|
"""
|
|
A general attention module using the flash attention kernels from xformers:
|
|
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
|
|
"""
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from .utils import SeqLenInfo
|
|
|
|
allow_alibi = True
|
|
for op in MemoryEfficientAttentionCutlassOp:
|
|
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
|
|
|
|
def mem_eff_attention(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
v: torch.Tensor,
|
|
seq_len_info_q: SeqLenInfo,
|
|
seq_len_info_kv: SeqLenInfo,
|
|
bias: Optional[torch.Tensor] = None,
|
|
dropout_p: float = 0.0,
|
|
scale: float = None,
|
|
causal: bool = False,
|
|
padded: bool = False,
|
|
):
|
|
attn_bias = None
|
|
if padded: # bert style
|
|
if not causal:
|
|
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
|
|
else:
|
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
|
|
elif causal: # gpt style
|
|
attn_bias = LowerTriangularMask()
|
|
|
|
if bias is not None: # alibi / relative position embedding
|
|
assert allow_alibi, "flash attention with bias is not supported in this system."
|
|
assert causal, "attention with bias is only supported for causal attention so far."
|
|
attn_bias = attn_bias.add_bias(bias)
|
|
|
|
if padded:
|
|
q = q.unsqueeze(0)
|
|
k = k.unsqueeze(0)
|
|
v = v.unsqueeze(0)
|
|
|
|
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
|
|
|
|
# shape: (b*s, n, d)
|
|
if padded:
|
|
out = out.squeeze(0)
|
|
|
|
return out
|