fix the sp

pull/6061/head
wangbluo 2024-09-13 02:32:03 +00:00
parent a35a078f08
commit fdd84b9087
2 changed files with 30 additions and 4 deletions

View File

@@ -118,6 +118,8 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
 
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
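
For context (not part of the commit): a minimal usage sketch of the new loader, assuming only the KernelLoader API already used elsewhere in this diff, i.e. that load() resolves the first available extension in REGISTRY and returns the kernel callable. flash-attn must be installed for the Dao extension to be available.

    from colossalai.kernel.kernel_loader import FlashAttentionDaoLoader

    # Resolve the Dao FlashAttention kernel; this is the same load() call the
    # dispatch code in the second file uses once the mask exceeds MEMORY_BOUND.
    flash_attn_kernel = FlashAttentionDaoLoader().load()
    print(callable(flash_attn_kernel))  # expected: True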

View File

@@ -10,6 +10,7 @@ from einops import rearrange
 from colossalai.kernel.kernel_loader import (
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
+    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
@@ -17,6 +18,8 @@ from colossalai.logging import get_dist_logger
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
@@ -104,7 +107,7 @@ class ColoAttention:
             }
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size: int) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +116,16 @@ class ColoAttention:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
+        # Oversized masks bypass the per-dtype map and use the Dao (flash-attn) kernel.
+        if size > MEMORY_BOUND:
+            dao_kernel = FlashAttentionDaoLoader().load()
+
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        return dao_kernel if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -163,7 +170,7 @@ class ColoAttention:
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
             if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
+                attention_mask.tril_(diagonal=0)
             attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
@@ -197,6 +204,15 @@ class ColoAttention:
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            outputs["attention_mask"] = attention_mask
+            if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
+                outputs["attention_mask_type"] = AttnMaskType.CAUSAL
+
         return outputs
 
     @staticmethod
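
To make the shortcut above concrete: once the projected mask size crosses MEMORY_BOUND, prepare_attn_kwargs stores an empty placeholder instead of the dense mask, and the mask type falls back to CAUSAL unless it was PADDED_CAUSAL. A tiny illustration of the placeholder (dtype and device are arbitrary here):

    import torch

    # The same empty tensor the hunk above stores in outputs["attention_mask"];
    # dispatch is then expected to rely on is_causal / cu_seqlens rather than a dense mask.
    placeholder = torch.empty((0,), dtype=torch.float16, device="cpu")
    print(placeholder.numel(), tuple(placeholder.shape))  # 0 (0,)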
@@ -278,8 +294,16 @@ class ColoAttention:
             assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
+            assert attention_mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.PADDED)
+
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
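
Finally, a hedged end-to-end sketch of the new dispatch path in this last hunk. The import location is an assumption (this view does not show file paths), the shapes are hypothetical, and a CUDA device with a usable flash-attention or SDPA extension is assumed; this is a sketch, not the library's documented API.

    import torch

    # Assumed re-export location for the classes shown in this diff (see __all__
    # above); adjust the import if the module lives elsewhere.
    from colossalai.shardformer.layer import AttnMaskType, ColoAttention

    q = torch.randn(2, 16, 4096, 64, dtype=torch.float16, device="cuda")
    v = torch.randn_like(q)

    # Mirrors the added lines above: estimate the dense-mask footprint from q/v.
    b, _, s_q, _ = q.shape
    b, _, s_kv, _ = v.shape
    element_size = torch.tensor([], dtype=q.dtype).element_size()
    memory_size = s_q * s_kv * element_size  # ~33.6 MB here, far below MEMORY_BOUND

    # _dispatch_kernel is normally called internally by ColoAttention.attention();
    # with the new third argument, a memory_size above MEMORY_BOUND selects the
    # Dao kernel, otherwise the per-dtype/mask-type kernel is returned.
    attn_func = ColoAttention._dispatch_kernel(q.dtype, AttnMaskType.CAUSAL, memory_size)
    print(attn_func)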