From dc032172c34538abcdad101997b9637b70ef0552 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 06:00:58 +0000
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 8890da242..a2ea761bf 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size < MEMORY_BOUND and not is_causal:
+            if memory_size < MEMORY_BOUND:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)
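
Note on the change: this branch runs right after `assert is_causal`, so `is_causal` is
always True here and the old condition `memory_size < MEMORY_BOUND and not is_causal`
could never be satisfied; the explicit causal mask below it was dead code. Dropping
`and not is_causal` lets the mask be materialized whenever it fits the memory budget.

The following is a minimal, self-contained sketch of the corrected guard, not
ColossalAI's actual API: the function name build_causal_mask, the MEMORY_BOUND value,
and the None fallback are assumptions for illustration; only the guard and the mask
construction mirror the hunk above.

import torch

MEMORY_BOUND = 10 * 1024**3  # hypothetical byte threshold; the real constant is defined in attn.py

def build_causal_mask(s_q, s_kv, memory_size, is_causal=True,
                      dtype=torch.float32, device="cpu"):
    # Mirrors the context line above the hunk: this branch is only
    # reached for causal attention.
    assert is_causal
    # Old guard: `memory_size < MEMORY_BOUND and not is_causal` was
    # unreachable after the assert, so no mask was ever built.
    # New guard: build the mask whenever it fits the memory budget.
    if memory_size < MEMORY_BOUND:
        attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
        if s_q != 1:
            attention_mask.tril_(diagonal=0)  # keep the lower triangle (causal)
        return attention_mask
    return None  # assumed fallback: let the attention kernel apply causality implicitly

For example, build_causal_mask(4, 4, memory_size=1024) returns a 4x4
lower-triangular mask of ones, whereas under the old guard it would have
returned nothing at all.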