From 6eb8832366c76187350059985d780acebbcd9a2d Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 05:06:56 +0000
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 65051a61f..c18d57de1 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -117,7 +117,7 @@ class ColoAttention:
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
 
-        if size > MEMORY_BOUND:
+        if size >= MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size > MEMORY_BOUND:
+            if memory_size >= MEMORY_BOUND:
                 attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
@@ -192,10 +192,10 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size > MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -208,7 +208,7 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size > MEMORY_BOUND:
+                if memory_size >= MEMORY_BOUND:
                     attention_mask = torch.empty((0,), dtype=dtype, device=device)
                 else:
                     if s_q != 1:
@@ -303,9 +303,6 @@ class ColoAttention:
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
         memory_size = s_q * s_kv * element_size
-        if memory_size > MEMORY_BOUND:
-            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
-
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
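
Context for the patch above (not part of the diff): the commit tightens the `MEMORY_BOUND` threshold checks from `>` to `>=`, builds the dense padding mask only when the mask fits under the bound and the attention is not causal, and removes an ineffective assert in `attention()` (its `or AttnMaskType.PADDED` clause made the condition always true). The snippet below is a minimal, self-contained sketch of that gating idea, not the actual `ColoAttention.prepare_attn_kwargs` API: `build_causal_mask` and the `memory_bound` byte threshold are hypothetical names used for illustration. When the dense `(s_q, s_kv)` mask would reach the bound, an empty tensor is returned so a memory-efficient kernel such as FlashAttention can be dispatched instead of materializing the full mask.

```python
import torch

# Hypothetical standalone sketch (not the ColossalAI API): illustrates the
# size gating that the patch adjusts. If the dense (s_q, s_kv) mask would
# need at least `memory_bound` bytes, no mask is materialized and an empty
# placeholder tensor is returned; below the bound, a dense causal mask is
# built as usual.
def build_causal_mask(
    s_q: int,
    s_kv: int,
    dtype: torch.dtype,
    device: torch.device,
    memory_bound: int,  # assumed byte threshold, analogous to MEMORY_BOUND
) -> torch.Tensor:
    element_size = torch.tensor([], dtype=dtype).element_size()
    memory_size = s_q * s_kv * element_size  # bytes a dense mask would need
    if memory_size >= memory_bound:          # note `>=`, matching the patch
        return torch.empty((0,), dtype=dtype, device=device)
    mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
    if s_q != 1:
        mask.tril_(diagonal=0)               # lower-triangular causal mask
    return mask

# Example: a 4096 x 4096 fp16 mask needs 32 MiB, so with a 16 MiB bound the
# dense mask is skipped and the empty placeholder is returned.
mask = build_causal_mask(4096, 4096, torch.float16, torch.device("cpu"), 16 * 1024**2)
print(mask.numel())  # 0 -> empty placeholder, dense mask was not built
```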