mirror of https://github.com/hpcaitech/ColossalAI
[coloattention]modify coloattention (#5627)
* modify coloattention
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix fxi
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5646/head
parent
7ee569b05f
commit
148506c828
|
@ -113,10 +113,6 @@ class FlashAttentionLoader(KernelLoader):
|
|||
]
|
||||
|
||||
|
||||
class FlashAttentionWithPaddingMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
|
||||
|
||||
|
||||
class FlashAttentionWithCustomMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from colossalai.kernel.kernel_loader import (
|
|||
FlashAttentionForFloatAndCustomMaskLoader,
|
||||
FlashAttentionLoader,
|
||||
FlashAttentionWithCustomMaskLoader,
|
||||
FlashAttentionWithPaddingMaskLoader,
|
||||
KernelLoader,
|
||||
)
|
||||
|
||||
|
@ -65,15 +64,17 @@ class ColoAttention:
|
|||
half_dispatch_map = {
|
||||
None: FlashAttentionLoader(),
|
||||
AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
|
||||
AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
|
||||
AttnMaskType.PADDED: FlashAttentionLoader(),
|
||||
AttnMaskType.CAUSAL: FlashAttentionLoader(),
|
||||
AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
|
||||
AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
|
||||
}
|
||||
# fp32
|
||||
float_dispatch_map = {
|
||||
None: FlashAttentionForFloatAndCustomMaskLoader(),
|
||||
AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
|
||||
AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
|
||||
AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
|
||||
AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
|
||||
}
|
||||
ColoAttention._kernel_dispatch_map = {
|
||||
torch.float16: half_dispatch_map,
|
||||
|
@ -140,16 +141,22 @@ class ColoAttention:
|
|||
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
||||
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
|
||||
else:
|
||||
assert q_padding_mask.shape == (
|
||||
b,
|
||||
s_q,
|
||||
), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
|
||||
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
||||
if kv_padding_mask is None:
|
||||
# self attention
|
||||
kv_padding_mask = q_padding_mask
|
||||
assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
|
||||
max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
|
||||
else:
|
||||
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
|
||||
assert kv_padding_mask.shape == (
|
||||
b,
|
||||
s_kv,
|
||||
), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
|
||||
attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
|
||||
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
||||
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
|
||||
), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
|
||||
attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
|
||||
outputs.update(
|
||||
{
|
||||
"cu_seqlens_q": cu_seqlens_q,
|
||||
|
|
|
@ -4,11 +4,7 @@ from copy import copy
|
|||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.kernel.kernel_loader import (
|
||||
FlashAttentionLoader,
|
||||
FlashAttentionWithCustomMaskLoader,
|
||||
FlashAttentionWithPaddingMaskLoader,
|
||||
)
|
||||
from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
|
||||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer.attn import invert_mask
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype):
|
|||
if ext.is_available():
|
||||
ext.assert_compatible()
|
||||
avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
|
||||
for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
|
||||
ext = ext_cls()
|
||||
if ext.is_available():
|
||||
ext.assert_compatible()
|
||||
avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
|
||||
|
||||
test_sets = {
|
||||
"none": (lambda dtype: ({}, None), avail_attn_funcs),
|
||||
|
|
Loading…
Reference in New Issue