diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 65051a61f..c18d57de1 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -117,7 +117,7 @@ class ColoAttention:
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
 
-        if size > MEMORY_BOUND:
+        if size >= MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size > MEMORY_BOUND:
+            if memory_size >= MEMORY_BOUND:
                 attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
@@ -192,10 +192,10 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size > MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -208,7 +208,7 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size > MEMORY_BOUND:
+                if memory_size >= MEMORY_BOUND:
                     attention_mask = torch.empty((0,), dtype=dtype, device=device)
                 else:
                     if s_q != 1:
@@ -303,9 +303,6 @@ class ColoAttention:
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
         memory_size = s_q * s_kv * element_size
-        if memory_size > MEMORY_BOUND:
-            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
-
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (