diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 8890da242..a2ea761bf 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size < MEMORY_BOUND and not is_causal:
+            if memory_size < MEMORY_BOUND:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)
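
Context for reviewers: the branch above already contains `assert is_causal`, so the old guard `and not is_causal` could never be true, and the dense causal mask was never materialized even when it fit within the memory budget. The sketch below is illustrative only: `build_causal_mask` is a hypothetical stand-alone helper and the `MEMORY_BOUND` value is a placeholder, not the actual ColoAttention code; it just shows the mask construction that the fixed condition enables.

import torch

# Assumed threshold for illustration; the real MEMORY_BOUND is defined elsewhere
# in colossalai and may differ.
MEMORY_BOUND = 10 * 1024**3

def build_causal_mask(s_q, s_kv, memory_size, dtype=torch.float32, device="cpu"):
    # Hypothetical helper mirroring the fixed condition: materialize the mask
    # whenever it fits the memory budget. The removed `and not is_causal` guard
    # contradicted the `assert is_causal` above it, so this branch never ran.
    if memory_size < MEMORY_BOUND:
        attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
        if s_q != 1:
            # Zero out entries above the diagonal so query i only attends to keys <= i;
            # a single-query decode step (s_q == 1) attends to all keys.
            attention_mask.tril_(diagonal=0)
        return attention_mask
    return None  # fall back to the kernel's implicit causal handling

# Example: a 4x4 lower-triangular causal mask
print(build_causal_mask(4, 4, memory_size=4 * 4 * 4))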