@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
@@ -17,6 +18,8 @@ from colossalai.logging import get_dist_logger
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
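The new constant caps how many bytes a dense attention mask may occupy before the code below stops materializing it and defers to a mask-free FlashAttention path. As a back-of-the-envelope check (not part of the patch; the helper below is illustrative), a square fp16 mask crosses the 10 GB bound near sequence length ~70,700:

    import torch

    MEMORY_BOUND = 10 * 1e9  # 10 GB, as defined above

    def mask_bytes(s_q: int, s_kv: int, dtype: torch.dtype) -> int:
        # Bytes a dense (s_q, s_kv) mask would occupy for the given dtype.
        return s_q * s_kv * torch.tensor([], dtype=dtype).element_size()

    print(mask_bytes(70_000, 70_000, torch.float16) >= MEMORY_BOUND)  # False (~9.8 GB)
    print(mask_bytes(71_000, 71_000, torch.float16) >= MEMORY_BOUND)  # True (~10.1 GB)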
@@ -77,6 +80,7 @@ def get_pad_info(
 
 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
 
     @staticmethod
     def _init_kernels_dispatch():
@@ -102,9 +106,11 @@ class ColoAttention:
             torch.bfloat16: half_dispatch_map,
             torch.float32: float_dispatch_map,
         }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size: int) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +119,21 @@ class ColoAttention:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
+
+        if size >= MEMORY_BOUND:
+            # lazy load; the isinstance guard keeps repeat calls from re-loading the kernel
+            if isinstance(ColoAttention._flash_kernel_dispatch, KernelLoader):
+                ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
+            return ColoAttention._flash_kernel_dispatch
+        else:
+            return ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
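To make the new routing concrete, here is a hypothetical call (illustrative sizes; `_dispatch_kernel` is an internal helper, and the returned kernel depends on what is installed):

    import torch

    # An 80k x 80k fp16 mask would take ~12.8 GB, which exceeds MEMORY_BOUND,
    # so causal mask types route to the mask-free Dao flash-attention kernel.
    s_q = s_kv = 80_000
    memory_size = s_q * s_kv * 2  # element_size() of torch.float16 is 2 bytes

    attn_func = ColoAttention._dispatch_kernel(torch.float16, AttnMaskType.CAUSAL, memory_size)
    # -> kernel loaded via FlashAttentionDaoLoader

    attn_func = ColoAttention._dispatch_kernel(torch.float16, AttnMaskType.CUSTOM, memory_size)
    # -> per-dtype dispatch map entry: a CUSTOM mask cannot be expressed
    #    by the mask-free flash-attention path, whatever its size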
@@ -154,6 +169,8 @@ class ColoAttention:
             return {}
         assert len(shape_4d) == 4 and shape_4d[1] == 1
         b, _, s_q, s_kv = shape_4d
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         outputs = {}
         if (q_padding_mask is None or q_padding_mask.bool().all()) and (
             kv_padding_mask is None or kv_padding_mask.bool().all()
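The empty-tensor probe is a cheap way to get bytes-per-element for any dtype; it allocates no storage. For instance:

    import torch

    # An empty tensor carries dtype metadata without allocating storage,
    # so element_size() is a free bytes-per-element lookup.
    assert torch.tensor([], dtype=torch.float16).element_size() == 2
    assert torch.tensor([], dtype=torch.bfloat16).element_size() == 2
    assert torch.tensor([], dtype=torch.float32).element_size() == 4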
@@ -161,10 +178,13 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
-            if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
-            attention_mask = attention_mask.expand(b, s_q, s_kv)
+            if memory_size < MEMORY_BOUND:
+                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+                if s_q != 1:
+                    attention_mask.tril_(diagonal=0)
+                attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
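The branch above reads as a small standalone policy; a minimal sketch (assuming, as the patch does, that an empty tensor downstream means "no materialized mask, apply causality inside the kernel"):

    import torch

    def build_causal_mask(b, s_q, s_kv, dtype, device, memory_bound=10 * 1e9):
        # Materialize the lower-triangular mask only when it fits the bound.
        if s_q * s_kv * torch.tensor([], dtype=dtype).element_size() < memory_bound:
            mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
            if s_q != 1:  # a single-query decode step attends to all keys
                mask.tril_(diagonal=0)
            return mask.expand(b, s_q, s_kv)
        # Placeholder: the flash-attention kernel enforces causality itself.
        return torch.empty((0,), dtype=dtype, device=device)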
@@ -177,7 +197,6 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -190,10 +209,19 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if s_q != 1:
-                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+                    if s_q != 1:
+                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
+                if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
+
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
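Note that the varlen metadata (`cu_seqlens_*`, `max_seqlen_*`, `q_indices`, `kv_indices`) is populated unconditionally, so padding is still honored when the dense mask is elided. A caller could distinguish the two cases like this (the helper name is hypothetical):

    def mask_was_materialized(attn_kwargs: dict) -> bool:
        # Over MEMORY_BOUND the patch stores an empty placeholder tensor
        # instead of a dense (b, 1, s_q, s_kv) mask.
        return attn_kwargs["attention_mask"].numel() > 0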
@@ -278,8 +306,12 @@ class ColoAttention:
             assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
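End to end, callers are unchanged: the forward pass derives `memory_size` from the query and value shapes and threads it into dispatch. A hedged usage sketch (tensor sizes are illustrative; the `prepare_attn_kwargs`/`attention` signatures are taken from the surrounding file):

    import torch

    b, h, s, d = 2, 8, 4096, 64
    q = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # 4096 x 4096 fp16 is ~32 MiB, far under MEMORY_BOUND, so the dense
    # causal mask is materialized and the regular dispatch map is used.
    attn_kwargs = ColoAttention.prepare_attn_kwargs(
        (b, 1, s, s), dtype=q.dtype, device=q.device, is_causal=True
    )
    out = ColoAttention.attention(q, k, v, **attn_kwargs)  # (b, h, s, d)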