From 0a01e2a453abaa802cb839f67f7193af52709350 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 03:33:08 +0000
Subject: [PATCH] fix the attn

---
 colossalai/shardformer/layer/attn.py | 44 +++++++++++++++-------------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 0a4f98535..1ffbae73e 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -18,7 +18,7 @@ from colossalai.logging import get_dist_logger
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
-MEMORY_BOUND = 10 * 1e9
+MEMORY_BOUND = 1 * 1e9
 
 __all__ = [
     "AttnMaskType",
@@ -125,9 +125,10 @@ class ColoAttention:
                 mask_type
             ].load()
 
-        return (
-            FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
-        )
+        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
+            return FlashAttentionDaoLoader()
+        else:
+            return ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -163,6 +164,8 @@ class ColoAttention:
             return {}
         assert len(shape_4d) == 4 and shape_4d[1] == 1
         b, _, s_q, s_kv = shape_4d
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         outputs = {}
         if (q_padding_mask is None or q_padding_mask.bool().all()) and (
             kv_padding_mask is None or kv_padding_mask.bool().all()
@@ -170,10 +173,13 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
-            if s_q != 1:
-                attention_mask.tril_(diagonal=0)
-            attention_mask = attention_mask.expand(b, s_q, s_kv)
+            if memory_size > MEMORY_BOUND:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            else:
+                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+                if s_q != 1:
+                    attention_mask.tril_(diagonal=0)
+                attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -186,7 +192,10 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+            if memory_size > MEMORY_BOUND:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            else:
+                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -199,22 +208,16 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if s_q != 1:
-                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if memory_size > MEMORY_BOUND:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
+                else:
+                    if s_q != 1:
+                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
-
-        element_size = torch.tensor([], dtype=dtype).element_size()
-        memory_size = s_q * s_kv * element_size
-        if memory_size > MEMORY_BOUND:
-            attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            outputs["attention_mask"] = attention_mask
-            if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
-                outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-
         return outputs
 
     @staticmethod
@@ -301,7 +304,6 @@ class ColoAttention:
         element_size = torch.tensor([], dtype=q.dtype).element_size()
         memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
-            attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
            mask_type = attention_mask_type if attention_mask is not None else None
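Reviewer note: below is a minimal sketch of the size check this change threads through prepare_attn_kwargs, assuming only torch and the MEMORY_BOUND value set by the patch; the helper name mask_would_be_materialized is invented for illustration and is not part of the patched file.

    import torch

    MEMORY_BOUND = 1 * 1e9  # bytes, as set by this patch

    def mask_would_be_materialized(s_q: int, s_kv: int, dtype: torch.dtype = torch.float16) -> bool:
        # element_size() returns the per-element byte width of the dtype (2 for fp16/bf16).
        element_size = torch.tensor([], dtype=dtype).element_size()
        memory_size = s_q * s_kv * element_size
        # At or below the bound the dense (b, s_q, s_kv) mask is built as before;
        # above it the patch keeps an empty placeholder mask and dispatches the
        # CAUSAL / PADDED_CAUSAL cases to FlashAttentionDaoLoader instead.
        return memory_size <= MEMORY_BOUND

For example, 4096 x 4096 in fp16 is about 32 MB and stays below the bound, while 32768 x 32768 (about 2.1 GB) exceeds it and takes the FlashAttention (Dao) path.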