2024-03-27 03:19:32 +00:00
|
|
|
from enum import Enum
|
|
|
|
from typing import Callable, Dict, Optional, Tuple
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from colossalai.kernel.kernel_loader import (
|
|
|
|
FlashAttentionForFloatAndCustomMaskLoader,
|
|
|
|
FlashAttentionLoader,
|
|
|
|
FlashAttentionWithCustomMaskLoader,
|
|
|
|
KernelLoader,
|
|
|
|
)
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
"AttnMaskType",
|
|
|
|
"ColoAttention",
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
class AttnMaskType(Enum):
    """Kinds of attention mask used to select an attention kernel."""

    # Mask supplied directly by the caller.
    CUSTOM = 0
    # Mask derived from padding masks only.
    PADDED = 1
    # Lower-triangular (causal) mask without padding.
    CAUSAL = 2
    # Padding mask combined with a causal mask.
    PADDED_CAUSAL = 3
|
|
|
|
|
|
|
|
|
|
|
|
def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    """Turn a keep-mask into an additive attention mask.

    Entries equal to ``1`` in ``mask`` become ``0``; every other entry becomes
    the most negative finite value representable in ``mask.dtype``.

    Args:
        mask (torch.Tensor): Floating point mask tensor. Shape should be [B, 1, Sq, Skv]

    Returns:
        torch.Tensor: Inverted mask tensor.
    """
    flipped = 1.0 - mask
    # Any non-zero entry after flipping marks a position that must be masked out.
    masked_positions = flipped.bool()
    lowest = torch.finfo(mask.dtype).min
    return flipped.masked_fill(masked_positions, lowest)
|
|
|
|
|
|
|
|
|
|
|
|
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
|
|
|
|
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """Derive variable-length sequence metadata from a padding mask.

    Args:
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]

    Returns:
        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices).
            ``cu_seqlens`` is the int32 cumulative sum of per-row lengths with a
            leading zero, and ``indices`` holds the positions of the non-zero
            entries of the flattened mask.
    """
    # Number of non-zero entries in each row of the mask.
    lengths = padding_mask.sum(dim=-1, dtype=torch.int32)
    flat_positions = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    longest = lengths.max().item()
    # Prepend a zero so that cu_seqlens[i]:cu_seqlens[i + 1] delimits row i.
    running_total = torch.cumsum(lengths, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(running_total, (1, 0))
    return longest, cu_seqlens, flat_positions
|
|
|
|
|
|
|
|
|
|
|
|
class ColoAttention:
    """Attention entry point that dispatches to a kernel by dtype and mask type.

    The class is used through its static methods only:
    ``prepare_attn_kwargs`` builds the mask-related keyword arguments and
    ``attention`` validates them and calls the selected kernel.
    """

    # Lazily built table: dtype -> {mask type (or None when no mask) -> loader or loaded kernel}.
    _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None

    @staticmethod
    def _init_kernels_dispatch():
        """Build the dispatch table on first use; later calls do nothing."""
        if ColoAttention._kernel_dispatch_map is None:
            # fp16/bf16
            half_dispatch_map = {
                None: FlashAttentionLoader(),
                AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionLoader(),
                AttnMaskType.CAUSAL: FlashAttentionLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
            }
            # fp32
            float_dispatch_map = {
                None: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
            }
            ColoAttention._kernel_dispatch_map = {
                torch.float16: half_dispatch_map,
                torch.bfloat16: half_dispatch_map,
                torch.float32: float_dispatch_map,
            }

    @staticmethod
    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
        """Return the attention kernel registered for ``dtype`` and ``mask_type``.

        Args:
            dtype (torch.dtype): Dtype of the query tensor.
            mask_type (Optional[AttnMaskType]): Mask type, or None when no attention mask is used.

        Returns:
            Callable: The loaded attention kernel.

        Raises:
            ValueError: If no kernel is registered for the given combination.
        """
        ColoAttention._init_kernels_dispatch()
        dispatch_map = ColoAttention._kernel_dispatch_map
        if dtype not in dispatch_map or mask_type not in dispatch_map[dtype]:
            raise ValueError(
                "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
            )
        kernel = dispatch_map[dtype][mask_type]
        # lazy load: replace the loader with the kernel it produces so it is loaded only once
        if isinstance(kernel, KernelLoader):
            kernel = kernel.load()
            dispatch_map[dtype][mask_type] = kernel
        return kernel

    @staticmethod
    def prepare_attn_kwargs(
        shape_4d: Tuple[int, ...],
        dtype: torch.dtype,
        device: torch.device,
        q_padding_mask: Optional[torch.Tensor] = None,
        kv_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Return a dictionary of keyword arguments for attention function. It supports 4 mask types.
            1. custom mask: no padding (no padding mask, or padding masks without any padding token) and is_causal=False, return {}, users should handle attention mask by themselves.
            2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
            3. causal mask: no padding and is_causal=True, return {attention_mask, attention_mask_type}.
            4. padded causal mask: recv padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.

        Args:
            shape_4d (Tuple[int, ...]): Should be (B, 1, Sq, Skv)
            dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
            device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor.
                The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor.
                The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
            is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
        """
        if q_padding_mask is None and not is_causal:
            return {}
        assert len(shape_4d) == 4 and shape_4d[1] == 1
        b, _, s_q, s_kv = shape_4d
        outputs = {}
        if (q_padding_mask is None or q_padding_mask.bool().all()) and (
            kv_padding_mask is None or kv_padding_mask.bool().all()
        ):
            # no padding
            if not is_causal:
                # Padding masks without any padding token and no causal masking:
                # nothing has to be masked, same as the "no padding mask" case above.
                return {}
            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
        else:
            assert (
                q_padding_mask is not None
            ), "q_padding_mask is required when kv_padding_mask contains padding tokens."
            assert q_padding_mask.shape == (
                b,
                s_q,
            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
            if kv_padding_mask is None:
                # self attention
                kv_padding_mask = q_padding_mask
                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
            else:
                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
            assert kv_padding_mask.shape == (
                b,
                s_kv,
            ), f"kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
            # The mask is [B, Sq, Skv]: every query row carries the validity of the
            # key/value positions, so it must be built from the kv padding mask.
            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
            outputs.update(
                {
                    "cu_seqlens_q": cu_seqlens_q,
                    "cu_seqlens_kv": cu_seqlens_kv,
                    "max_seqlen_q": max_seqlen_q,
                    "max_seqlen_kv": max_seqlen_kv,
                    "q_indices": q_indices,
                    "kv_indices": kv_indices,
                }
            )
            if is_causal:
                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
            else:
                outputs["attention_mask_type"] = AttnMaskType.PADDED
        # Convert the 1/0 keep-mask to an additive mask and add the head dimension: [B, 1, Sq, Skv].
        attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask
        return outputs

    @staticmethod
    def attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        q_indices: Optional[torch.Tensor] = None,
        kv_indices: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """Flash Attention function. It supports 4 mask types.
            1. custom mask: recv attention_mask
            2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
            3. causal mask: recv attention_mask, attention_mask_type
            4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1]. Defaults to None.
            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
                Shape should be [B+1]. Defaults to None.
            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
            q_indices (Optional[torch.Tensor], optional): The indices of non-masked query tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            kv_indices (Optional[torch.Tensor], optional): The indices of non-masked key/value tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

        Returns:
            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
        """
        # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
        # this case is usual when padding mask is used and self attention is performed
        # thus, we don't use sdpa when padding mask is used
        # sanity check
        if attention_mask is not None:
            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
                assert (
                    cu_seqlens_q is None
                    and cu_seqlens_kv is None
                    and max_seqlen_q is None
                    and max_seqlen_kv is None
                    and q_indices is None
                    and kv_indices is None
                )
                if attention_mask_type == AttnMaskType.CUSTOM:
                    # reject masks in which some row has every position masked (all entries non-zero)
                    assert not torch.all(attention_mask != 0, dim=-1).any()
            elif attention_mask_type in (
                AttnMaskType.PADDED,
                AttnMaskType.PADDED_CAUSAL,
            ):
                assert (
                    cu_seqlens_q is not None
                    and cu_seqlens_kv is not None
                    and max_seqlen_q is not None
                    and max_seqlen_kv is not None
                    and q_indices is not None
                    and kv_indices is not None
                )
        else:
            # if attention_mask is None, attention_mask_type should be the default value
            assert attention_mask_type == AttnMaskType.CUSTOM
        # kernel dispatch
        mask_type = attention_mask_type if attention_mask is not None else None
        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
        is_causal = attention_mask is not None and attention_mask_type in (
            AttnMaskType.CAUSAL,
            AttnMaskType.PADDED_CAUSAL,
        )
        return attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            scale=scale,
            attention_mask=attention_mask,
            is_causal=is_causal,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            q_indices=q_indices,
            kv_indices=kv_indices,
        )
|