ColossalAI/colossalai/shardformer/layer/attn.py

277 lines
13 KiB
Python
Raw Normal View History

from enum import Enum
from typing import Callable, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import (
FlashAttentionForFloatAndCustomMaskLoader,
FlashAttentionLoader,
FlashAttentionWithCustomMaskLoader,
KernelLoader,
)
# Names exported by ``from <this module> import *``; the helper functions
# ``invert_mask`` and ``get_pad_info`` are deliberately not listed here.
__all__ = [
"AttnMaskType",
"ColoAttention",
]
class AttnMaskType(Enum):
    """Kinds of attention mask used to select and configure an attention kernel."""

    # Arbitrary mask supplied by the caller.
    CUSTOM = 0
    # Mask derived from padding information only.
    PADDED = 1
    # Lower-triangular (causal) mask without padding.
    CAUSAL = 2
    # Padding mask combined with a causal mask.
    PADDED_CAUSAL = 3
def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    """Turn a keep-mask into an additive mask.

    Positions where ``mask`` is ``1`` become ``0``; every position where
    ``1 - mask`` is non-zero is filled with the most negative finite value
    of ``mask.dtype``.

    Args:
        mask (torch.Tensor): Mask tensor of a floating point dtype.
            Shape should be [B, 1, Sq, Skv].

    Returns:
        torch.Tensor: Inverted mask tensor.
    """
    flipped = 1.0 - mask
    # Any non-zero entry of the flipped mask marks a position to block.
    blocked = flipped.bool()
    most_negative = torch.finfo(mask.dtype).min
    return flipped.masked_fill(blocked, most_negative)
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """Get padding information from padding mask.

    Args:
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S].
            Non-zero entries mark valid tokens.

    Returns:
        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices).
            ``max_seq_len`` is the largest per-row count of valid tokens,
            ``cu_seqlens`` is the int32 cumulative sum of per-row counts with a
            leading zero (shape [B + 1]), and ``indices`` holds the positions of
            valid tokens in the flattened mask.
    """
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Prepend a zero so that cu_seqlens[i]:cu_seqlens[i + 1] spans sequence i.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return max_seqlen_in_batch, cu_seqlens, indices
class ColoAttention:
    """Entry point that builds attention keyword arguments and dispatches to an attention kernel.

    Kernels are chosen by the dtype of the query tensor and by the mask type.
    All methods are static; the class only holds the shared dispatch table.
    """

    # dtype -> mask type (``None`` when no attention mask is given) -> kernel.
    # Entries start out as loader objects and are replaced by the result of
    # ``load()`` the first time they are requested.
    _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None

    @staticmethod
    def _init_kernels_dispatch():
        """Build the dispatch table once; later calls are no-ops."""
        if ColoAttention._kernel_dispatch_map is None:
            # fp16/bf16
            half_dispatch_map = {
                None: FlashAttentionLoader(),
                AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionLoader(),
                AttnMaskType.CAUSAL: FlashAttentionLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
            }
            # fp32
            float_dispatch_map = {
                None: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
            }
            # fp16 and bf16 intentionally share the same inner dict.
            ColoAttention._kernel_dispatch_map = {
                torch.float16: half_dispatch_map,
                torch.bfloat16: half_dispatch_map,
                torch.float32: float_dispatch_map,
            }

    @staticmethod
    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
        """Return the kernel registered for ``dtype`` and ``mask_type``.

        Args:
            dtype (torch.dtype): Dtype of the query tensor.
            mask_type (Optional[AttnMaskType]): Mask type, or ``None`` when no mask is used.

        Returns:
            Callable: The loaded attention kernel.

        Raises:
            ValueError: If no kernel is registered for the given combination.
        """
        ColoAttention._init_kernels_dispatch()
        dispatch_map = ColoAttention._kernel_dispatch_map
        if dtype not in dispatch_map or mask_type not in dispatch_map[dtype]:
            raise ValueError(
                "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
            )
        kernels = dispatch_map[dtype]
        # lazy load: swap the loader for the kernel it produces on first use
        if isinstance(kernels[mask_type], KernelLoader):
            kernels[mask_type] = kernels[mask_type].load()
        return kernels[mask_type]

    @staticmethod
    def prepare_attn_kwargs(
        shape_4d: Tuple[int],
        dtype: torch.dtype,
        device: torch.device,
        q_padding_mask: Optional[torch.Tensor] = None,
        kv_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Return a dictionary of keyword arguments for attention function. It supports 4 mask type.

        1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
           The same empty dict is returned when the given padding masks contain no padding and is_causal=False,
           because no masking is needed in that case.
        2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
        3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}.
        4. padded causal mask: recv padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.

        Args:
            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
            dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
            device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor.
                The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor.
                The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
                It is ignored when ``q_padding_mask`` is None and ``is_causal`` is False.
            is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
        """
        if q_padding_mask is None and not is_causal:
            return {}
        assert len(shape_4d) == 4 and shape_4d[1] == 1
        b, _, s_q, s_kv = shape_4d
        outputs = {}
        if (q_padding_mask is None or q_padding_mask.bool().all()) and (
            kv_padding_mask is None or kv_padding_mask.bool().all()
        ):
            # no padding
            if not is_causal:
                # Padding masks were given but every token is valid, so there is
                # nothing to mask; behave like the "no mask" case instead of failing.
                return {}
            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
        else:
            assert (
                q_padding_mask is not None
            ), "q_padding_mask is required when kv_padding_mask contains padding tokens."
            assert q_padding_mask.shape == (
                b,
                s_q,
            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
            if kv_padding_mask is None:
                # self attention
                kv_padding_mask = q_padding_mask
                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
            else:
                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
            assert kv_padding_mask.shape == (
                b,
                s_kv,
            ), f"kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
            # The mask hides padded *key/value* positions for every query row, so it is
            # built from kv_padding_mask with shape [B, Sq, Skv]. For self attention this
            # equals the mask built from q_padding_mask.
            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
            outputs.update(
                {
                    "cu_seqlens_q": cu_seqlens_q,
                    "cu_seqlens_kv": cu_seqlens_kv,
                    "max_seqlen_q": max_seqlen_q,
                    "max_seqlen_kv": max_seqlen_kv,
                    "q_indices": q_indices,
                    "kv_indices": kv_indices,
                }
            )
            if is_causal:
                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
            else:
                outputs["attention_mask_type"] = AttnMaskType.PADDED
        # Convert the 0/1 keep-mask to an additive mask and add the head dimension: [B, 1, Sq, Skv].
        attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask
        return outputs

    @staticmethod
    def attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        q_indices: Optional[torch.Tensor] = None,
        kv_indices: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """Flash Attention function. It supports 4 mask type.

        1. custom mask: recv attention_mask
        2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
        3. causal mask: recv attention_mask, attention_mask_type
        4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1]. Defaults to None.
            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
                Shape should be [B+1]. Defaults to None.
            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
            q_indices (Optional[torch.Tensor], optional): The indices of non-masked query tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            kv_indices (Optional[torch.Tensor], optional): The indices of non-masked key/value tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

        Returns:
            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]

        Raises:
            ValueError: If no kernel is registered for ``q.dtype`` and the mask type.
        """
        # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
        # this case is usual when padding mask is used and self attention is performed
        # thus, we don't use sdpa when padding mask is used
        # sanity check
        if attention_mask is not None:
            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
                # Unpadded mask types must not carry any padding metadata.
                assert (
                    cu_seqlens_q is None
                    and cu_seqlens_kv is None
                    and max_seqlen_q is None
                    and max_seqlen_kv is None
                    and q_indices is None
                    and kv_indices is None
                )
                if attention_mask_type == AttnMaskType.CUSTOM:
                    # Reject masks in which some row is non-zero everywhere (a fully masked row).
                    assert not torch.all(attention_mask != 0, dim=-1).any()
            elif attention_mask_type in (
                AttnMaskType.PADDED,
                AttnMaskType.PADDED_CAUSAL,
            ):
                # Padded mask types need the complete set of padding metadata.
                assert (
                    cu_seqlens_q is not None
                    and cu_seqlens_kv is not None
                    and max_seqlen_q is not None
                    and max_seqlen_kv is not None
                    and q_indices is not None
                    and kv_indices is not None
                )
        else:
            # if attention_mask is None, attention_mask_type should be the default value
            assert attention_mask_type == AttnMaskType.CUSTOM
        # kernel dispatch
        mask_type = attention_mask_type if attention_mask is not None else None
        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
        is_causal = attention_mask is not None and attention_mask_type in (
            AttnMaskType.CAUSAL,
            AttnMaskType.PADDED_CAUSAL,
        )
        return attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            scale=scale,
            attention_mask=attention_mask,
            is_causal=is_causal,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            q_indices=q_indices,
            kv_indices=kv_indices,
        )