mirror of https://github.com/hpcaitech/ColossalAI
commit 37e35230ff
@@ -119,6 +119,10 @@ class FlashAttentionLoader(KernelLoader):
     ]
 
 
+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
+
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
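The hunk above adds a dedicated loader for the Dao flash-attention CUDA extension alongside the existing loaders (the module is colossalai.kernel.kernel_loader, judging from the import added in the next hunk). As with the other KernelLoader subclasses touched by this commit, the class is instantiated and then .load() resolves a concrete kernel. A minimal, hedged usage sketch; the name flash_attn_func and the idea of calling the loader outside ColoAttention are illustrative assumptions:

# Sketch only: instantiate the new loader and resolve a kernel.
# FlashAttentionDaoLoader and .load() appear in this commit; everything else
# here is an assumption for illustration.
from colossalai.kernel.kernel_loader import FlashAttentionDaoLoader

flash_attn_func = FlashAttentionDaoLoader().load()  # resolves FlashAttentionDaoCudaExtension when available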
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
@@ -17,6 +18,8 @@ from colossalai.logging import get_dist_logger
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
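MEMORY_BOUND is a byte threshold (10 * 1e9, i.e. 10 GB): as the later hunks show, it is compared against s_q * s_kv * element_size, the footprint of a dense attention mask. A quick, hedged back-of-the-envelope check of where the bound kicks in for half-precision masks:

import torch

# 2 bytes per element for bf16/fp16, so the dense-mask footprint is s_q * s_kv * 2.
element_size = torch.tensor([], dtype=torch.bfloat16).element_size()   # 2
print(8_192 * 8_192 * element_size)      # 134_217_728 bytes (~0.13 GB), well under MEMORY_BOUND
print(80_000 * 80_000 * element_size)    # 12_800_000_000 bytes (~12.8 GB), over MEMORY_BOUND
# The crossover sits around s_q = s_kv ~= 70_700 tokens for 2-byte dtypes.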
@@ -77,6 +80,7 @@ def get_pad_info(
 
 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
 
     @staticmethod
     def _init_kernels_dispatch():
@@ -102,9 +106,11 @@ class ColoAttention:
                 torch.bfloat16: half_dispatch_map,
                 torch.float32: float_dispatch_map,
             }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +119,19 @@ class ColoAttention:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
+
+        if size >= MEMORY_BOUND:
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
+            return ColoAttention._flash_kernel_dispatch
+        else:
+            return ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
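In words: once the estimated mask footprint reaches MEMORY_BOUND and the mask type is causal (padded or not), dispatch returns the Dao flash-attention kernel instead of the per-dtype kernel, so no dense mask has to be materialized. A hedged sketch of how the new signature would be exercised (the sizes and dtype are made-up; ColoAttention, AttnMaskType and MEMORY_BOUND come from this file):

import torch

small = 4_096 * 4_096 * 2        # ~33 MB: falls through to the regular per-dtype kernel
large = 131_072 * 131_072 * 2    # ~34 GB: >= MEMORY_BOUND, so the Dao kernel is returned for causal masks
attn_small = ColoAttention._dispatch_kernel(torch.bfloat16, AttnMaskType.CAUSAL, small)
attn_large = ColoAttention._dispatch_kernel(torch.bfloat16, AttnMaskType.CAUSAL, large)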
@@ -154,6 +167,8 @@ class ColoAttention:
             return {}
         assert len(shape_4d) == 4 and shape_4d[1] == 1
         b, _, s_q, s_kv = shape_4d
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         outputs = {}
         if (q_padding_mask is None or q_padding_mask.bool().all()) and (
             kv_padding_mask is None or kv_padding_mask.bool().all()
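element_size() reports the per-element byte width of the dtype, so memory_size is exactly the number of bytes a dense (s_q, s_kv) mask would occupy. For reference:

import torch

torch.tensor([], dtype=torch.float16).element_size()   # 2
torch.tensor([], dtype=torch.bfloat16).element_size()  # 2
torch.tensor([], dtype=torch.float32).element_size()   # 4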
@@ -161,10 +176,13 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
-            if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
-            attention_mask = attention_mask.expand(b, s_q, s_kv)
+            if memory_size < MEMORY_BOUND:
+                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+                if s_q != 1:
+                    attention_mask.tril_(diagonal=0)
+                attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
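The unpadded causal branch now only materializes the dense lower-triangular mask when it fits under MEMORY_BOUND; otherwise an empty tensor stands in as a placeholder, and the large-mask case is handled by the flash kernel selected in _dispatch_kernel. A toy illustration of the two outcomes, with made-up sizes:

import torch

# Small case: dense causal mask, built in place with tril_ and broadcast over the batch.
mask = torch.ones(4, 4, dtype=torch.bfloat16)
mask.tril_(diagonal=0)
mask = mask.expand(1, 4, 4)

# Large case: no dense mask, just the empty placeholder from the else branch.
placeholder = torch.empty((0,), dtype=torch.bfloat16)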
@@ -177,7 +195,6 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -190,10 +207,17 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if s_q != 1:
-                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if memory_size < MEMORY_BOUND:
+                    if s_q != 1:
+                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
+                if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
 
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
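For padded inputs the same size check applies: the dense mask (the key-padding mask broadcast over query positions, combined with a lower-triangular causal mask when is_causal) is only built when it is small enough, and otherwise the cu_seqlens/indices metadata prepared above carries the padding information. A hedged toy example of the dense padded-causal construction, with a batch of one and the last key masked out:

import torch

kv_padding_mask = torch.tensor([[1, 1, 1, 0]])                                     # (b, s_kv)
attention_mask = kv_padding_mask[:, None, :].expand(1, 4, 4).to(torch.bfloat16)    # (b, s_q, s_kv)
attention_mask = attention_mask * attention_mask.new_ones(4, 4).tril(diagonal=0)   # also zero out future positions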
@@ -278,8 +302,12 @@ class ColoAttention:
             assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
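At the call site the same mask-footprint estimate is computed from the runtime q/v shapes and threaded into _dispatch_kernel, so the Dao kernel gets picked for long causal sequences without the caller changing anything. A hedged sketch of just the size computation with made-up shapes (CPU tensors purely for illustration):

import torch

q = torch.randn(2, 16, 4_096, 128, dtype=torch.bfloat16)   # (b, heads, s_q, head_dim)
v = torch.randn(2, 16, 4_096, 128, dtype=torch.bfloat16)   # (b, heads, s_kv, head_dim)
b, _, s_q, _ = q.shape
b, _, s_kv, _ = v.shape
memory_size = s_q * s_kv * torch.tensor([], dtype=q.dtype).element_size()   # 33_554_432 bytes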