mirror of https://github.com/hpcaitech/ColossalAI

[coloattention] modify coloattention (#5627)

* modify coloattention
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

branch: pull/5646/head
parent: 7ee569b05f
commit: 148506c828
@@ -113,10 +113,6 @@ class FlashAttentionLoader(KernelLoader):
     ]
 
 
-class FlashAttentionWithPaddingMaskLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
-
-
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
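
Note (for readers skimming the diff): these loaders follow a registry pattern, where each KernelLoader subclass lists candidate extensions in REGISTRY and the first available, compatible one is loaded. The sketch below is a hypothetical, minimal rendering of that pattern for illustration only; the real KernelLoader's names and behavior may differ.

# Minimal sketch of a registry-based kernel loader (hypothetical, illustration only).
class ExtensionSketch:
    name = "base"

    def is_available(self) -> bool:
        raise NotImplementedError

    def assert_compatible(self) -> None:
        pass  # raise if the platform/toolchain does not satisfy the extension's requirements

    def load(self):
        raise NotImplementedError


class KernelLoaderSketch:
    REGISTRY = []  # subclasses list candidate extension classes here

    def load(self):
        # Pick the first extension that reports itself available on this platform.
        for ext_cls in self.REGISTRY:
            ext = ext_cls()
            if ext.is_available():
                ext.assert_compatible()
                return ext.load()
        raise RuntimeError("no compatible kernel extension found")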
@@ -8,7 +8,6 @@ from colossalai.kernel.kernel_loader import (
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
     KernelLoader,
 )
 
@@ -65,15 +64,17 @@ class ColoAttention:
             half_dispatch_map = {
                 None: FlashAttentionLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
-                AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionLoader(),
-                AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
             }
             # fp32
             float_dispatch_map = {
                 None: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
                 AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
             }
             ColoAttention._kernel_dispatch_map = {
                 torch.float16: half_dispatch_map,
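
The dispatch tables in the hunk above are keyed first by dtype and then by AttnMaskType; after this change the padded fp16/bf16 cases resolve to the generic FlashAttentionLoader, and the fp32 table gains explicit padded entries. A rough sketch of how such a two-level lookup could be consulted follows; the helper name is hypothetical, and a load() method is assumed as in the loader sketch above.

import torch

# Hypothetical helper showing a two-level dispatch lookup: dtype first, then mask type.
def pick_flash_attn_kernel(dispatch_map, dtype: torch.dtype, mask_type=None):
    try:
        loader = dispatch_map[dtype][mask_type]
    except KeyError as exc:
        raise ValueError(f"no flash attention kernel registered for {dtype} / {mask_type}") from exc
    # Assumes the registered loader exposes load(), returning the attention kernel callable.
    return loader.load()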
@@ -140,16 +141,22 @@ class ColoAttention:
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
         else:
+            assert q_padding_mask.shape == (
+                b,
+                s_q,
+            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
+            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
                 kv_padding_mask = q_padding_mask
-            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
+            else:
+                max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            assert kv_padding_mask.shape == (
                 b,
                 s_kv,
-            ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
+            ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
+            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
-            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
-            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
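
The padded branch above relies on get_pad_info to turn a (b, s) padding mask into the quantities a varlen flash-attention interface expects. The sketch below shows one plausible way to derive them; it assumes the usual flash-attn convention (max sequence length, cumulative sequence lengths with a leading zero, and indices of the non-pad tokens) and is not necessarily identical to the project's get_pad_info.

import torch
import torch.nn.functional as F

# Hedged sketch: derive (max_seqlen, cu_seqlens, indices) from a (b, s) 0/1 padding mask.
def get_pad_info_sketch(padding_mask: torch.Tensor):
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                         # valid tokens per sequence
    max_seqlen = int(seqlens.max())                                               # longest unpadded sequence
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # cumulative lengths, leading 0
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()     # positions of non-pad tokens
    return max_seqlen, cu_seqlens, indices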
@@ -4,11 +4,7 @@ from copy import copy
 import torch
 from torch.testing import assert_close
 
-from colossalai.kernel.kernel_loader import (
-    FlashAttentionLoader,
-    FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
-)
+from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer.attn import invert_mask
 from colossalai.testing import clear_cache_before_run, parameterize
@@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype):
         if ext.is_available():
             ext.assert_compatible()
             avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
-    for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
-        ext = ext_cls()
-        if ext.is_available():
-            ext.assert_compatible()
-            avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
 
     test_sets = {
         "none": (lambda dtype: ({}, None), avail_attn_funcs),
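
With the padding-mask loader removed, the test only collects kernels from the generic and custom-mask registries, following the is_available / assert_compatible / load protocol visible in the hunk above. A hedged sketch of that collection step, using a hypothetical helper name:

# Hypothetical helper mirroring the collection loop in the test above: gather
# (kernel_fn, name, supports_custom_mask) tuples for every extension available
# on the current platform.
def collect_available_kernels(loader_cls, supports_custom_mask: bool):
    funcs = []
    for ext_cls in loader_cls.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            funcs.append((ext.load(), ext.name, supports_custom_mask))
    return funcs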