diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 353e29b3d..2dff3bcbc 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -113,10 +113,6 @@ class FlashAttentionLoader(KernelLoader):
     ]
 
 
-class FlashAttentionWithPaddingMaskLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
-
-
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index f3f6e59d3..abc865a34 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -8,7 +8,6 @@ from colossalai.kernel.kernel_loader import (
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
     KernelLoader,
 )
 
@@ -65,15 +64,17 @@ class ColoAttention:
             half_dispatch_map = {
                 None: FlashAttentionLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
-                AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionLoader(),
-                AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
             }
             # fp32
             float_dispatch_map = {
                 None: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
             }
             ColoAttention._kernel_dispatch_map = {
                 torch.float16: half_dispatch_map,
@@ -140,16 +141,22 @@ class ColoAttention:
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
         else:
+            assert q_padding_mask.shape == (
+                b,
+                s_q,
+            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
+            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
                 kv_padding_mask = q_padding_mask
-            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
+            else:
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            assert kv_padding_mask.shape == (
                 b,
                 s_kv,
-            ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
-            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
-            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            ), f"kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
+            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py
index f9eab132f..9aa24a166 100644
--- a/tests/test_shardformer/test_flash_attention.py
+++ b/tests/test_shardformer/test_flash_attention.py
@@ -4,11 +4,7 @@ from copy import copy
 import torch
 from torch.testing import assert_close
 
-from colossalai.kernel.kernel_loader import (
-    FlashAttentionLoader,
-    FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
-)
+from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer.attn import invert_mask
 from colossalai.testing import clear_cache_before_run, parameterize
@@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype):
         if ext.is_available():
             ext.assert_compatible()
             avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
-    for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
-        ext = ext_cls()
-        if ext.is_available():
-            ext.assert_compatible()
-            avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
 
     test_sets = {
         "none": (lambda dtype: ({}, None), avail_attn_funcs),