From f393867cff97924e0b90d81758a29bc5a2e94923 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 05:24:52 +0000
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index c18d57de1..8890da242 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -173,13 +173,13 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size >= MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)
                 attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -208,11 +208,11 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size >= MEMORY_BOUND:
-                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
-                else:
+                if memory_size < MEMORY_BOUND:
                     if s_q != 1:
                         attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
                 if invert:
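
Note (not part of the patch): both hunks flip the memory check so the dense mask is only materialized when memory_size < MEMORY_BOUND, and the empty placeholder tensor moves into the else branch. The sketch below illustrates that decision for the no-padding causal case as a standalone helper; build_causal_mask and the MEMORY_BOUND value are assumptions made for illustration, while the tensor operations mirror the patched lines.

# Illustrative sketch, not code from the repository.
import torch

MEMORY_BOUND = 10 * 1024**3  # assumed threshold in bytes; the real constant lives in attn.py

def build_causal_mask(b, s_q, s_kv, dtype=torch.float16, device="cpu"):
    """Return a dense lower-triangular mask when it fits under MEMORY_BOUND,
    otherwise an empty placeholder tensor for the memory-efficient path."""
    memory_size = s_q * s_kv * torch.tensor([], dtype=dtype).element_size()
    if memory_size < MEMORY_BOUND:
        attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
        if s_q != 1:
            attention_mask.tril_(diagonal=0)  # causal: keep only the lower triangle
        return attention_mask.expand(b, s_q, s_kv)
    # large sequences: skip materializing the O(s_q * s_kv) mask
    return torch.empty((0,), dtype=dtype, device=device)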