mirror of https://github.com/hpcaitech/ColossalAI
fix the sp
parent a35a078f08
commit fdd84b9087
colossalai/kernel/kernel_loader.py

@@ -118,6 +118,8 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
 
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
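The new loader follows the existing `KernelLoader` pattern: `REGISTRY` lists candidate extensions and `load()` picks one that is available on the current machine. A minimal usage sketch (the call site is illustrative, not part of this commit):

```python
# Illustrative only: obtain the Dao flash-attn kernel through the new loader.
# KernelLoader.load() is assumed to select an available extension from
# REGISTRY and return its kernel callable.
from colossalai.kernel.kernel_loader import FlashAttentionDaoLoader

flash_attn_func = FlashAttentionDaoLoader().load()  # fails if flash-attn is not installed
```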
colossalai/shardformer/layer/attn.py

@@ -10,6 +10,7 @@ from einops import rearrange
 from colossalai.kernel.kernel_loader import (
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
+    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
@@ -17,6 +18,8 @@ from colossalai.logging import get_dist_logger
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
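`MEMORY_BOUND` is 10 * 1e9 bytes, i.e. a 10 GB ceiling on the dense attention mask that the later hunks guard against. A quick back-of-envelope check of when the bound trips (sequence lengths are hypothetical):

```python
# A dense (s_q, s_kv) mask in fp16 costs s_q * s_kv * 2 bytes.
MEMORY_BOUND = 10 * 1e9

s_q = s_kv = 80_000
element_size = 2  # bytes per torch.float16 element
memory_size = s_q * s_kv * element_size  # 1.28e10 bytes

print(memory_size > MEMORY_BOUND)  # True: the mask would exceed 10 GB
```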
@@ -104,7 +107,7 @@ class ColoAttention:
         }
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size: int) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +116,16 @@
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
 
+        if size > MEMORY_BOUND:
+            flash_kernel = FlashAttentionDaoLoader().load()
+
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        return flash_kernel if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
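With the extra `size` argument, `_dispatch_kernel` short-circuits to the Dao flash-attn kernel whenever the dense mask would cross `MEMORY_BOUND`, instead of consulting the per-dtype/mask-type map. A hedged sketch of the decision at a call site (shapes are made up):

```python
import torch

from colossalai.shardformer.layer.attn import ColoAttention

q = torch.randn(2, 32, 4096, 128, dtype=torch.float16, device="cuda")
s_q = s_kv = q.shape[2]
memory_size = s_q * s_kv * q.element_size()  # bytes a dense mask would need

# Small memory_size: the dtype/mask-type map supplies the kernel (e.g. SDPA).
# memory_size > MEMORY_BOUND: the Dao flash-attn kernel is returned instead.
attn_func = ColoAttention._dispatch_kernel(q.dtype, None, memory_size)
```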
@@ -163,7 +170,7 @@
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
             if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
+                attention_mask.tril_(diagonal=0)
             attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
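`tril_` is the in-place variant of `tril`, so the causal cut no longer allocates a second (s_q, s_kv) buffer before the broadcasting `expand`. The pattern in isolation (sizes are hypothetical):

```python
import torch

b, s_q, s_kv = 2, 4, 4
attention_mask = torch.ones(s_q, s_kv, dtype=torch.float16)
attention_mask.tril_(diagonal=0)  # in-place lower-triangular, no extra copy
attention_mask = attention_mask.expand(b, s_q, s_kv)  # broadcast view, no allocation
```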
@@ -197,6 +204,15 @@
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            outputs["attention_mask"] = attention_mask
+            if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
+                outputs["attention_mask_type"] = AttnMaskType.CAUSAL
+
         return outputs
 
     @staticmethod
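When the mask would exceed `MEMORY_BOUND`, `prepare_attn_kwargs` now returns an empty placeholder tensor instead of a dense mask and, except for the padded-causal case, downgrades the mask type to `CAUSAL` so the kernel applies causality itself. A toy distillation of the guard (`mask_or_placeholder` is hypothetical, not part of the codebase):

```python
import torch

MEMORY_BOUND = 10 * 1e9  # mirrors the constant added above

def mask_or_placeholder(s_q: int, s_kv: int, dtype: torch.dtype) -> torch.Tensor:
    """Toy version of the new guard: drop the dense mask when it is too big."""
    element_size = torch.tensor([], dtype=dtype).element_size()
    if s_q * s_kv * element_size > MEMORY_BOUND:
        return torch.empty((0,), dtype=dtype)  # placeholder; kernel handles causality
    return torch.ones(s_q, s_kv, dtype=dtype).tril_(diagonal=0)

print(mask_or_placeholder(1_024, 1_024, torch.float16).shape)      # torch.Size([1024, 1024])
print(mask_or_placeholder(128_000, 128_000, torch.float16).shape)  # torch.Size([0])
```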
@@ -278,8 +294,16 @@
             assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
+            assert attention_mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.PADDED)
+
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
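Taken together, `attention` now measures the would-be mask, swaps in an empty placeholder above the bound, and passes `memory_size` through to `_dispatch_kernel` so long sequences land on the Dao kernel. An end-to-end sketch under moderate shapes (illustrative; the `prepare_attn_kwargs` signature is assumed from the surrounding code, and a CUDA device with flash-attn or SDPA is assumed available):

```python
import torch

from colossalai.shardformer.layer.attn import ColoAttention

b, h, s, d = 2, 8, 4096, 64  # hypothetical shapes, well below MEMORY_BOUND
q = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (b, h, s, s), dtype=q.dtype, device=q.device, is_causal=True
)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # (b, h, s, d)
```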