diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 6dab17ec0..5d1a30d8a 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -690,6 +690,13 @@ class RingAttention(torch.autograd.Function):
             )
             return out, softmax_lse, rng_state
 
+        def _kv_comm(i):
+            # Avoid overwriting attn input when it shares mem with buffer
+            if not RingAttention.ATTN_DONE.query():
+                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+            if i < local_sp_size - 1:
+                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
@@ -698,12 +705,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
-
-                    # Avoid overwriting attn input when it shares mem with buffer
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+                    if i == 0:
+                        _kv_comm(i)
 
                     if i == 0:
                         # Compute with local KV; no mask
@@ -734,6 +737,9 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i],
                         ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                     RingAttention.ATTN_DONE.record()
+                    # Pipeline the next KV comm with output correction instead of the next flash attn
+                    # to minimize idle time when comm takes longer than attn.
+                    _kv_comm(i + 1)
 
                     block_softmax_lse[i % 2] = (
                         block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
@@ -761,15 +767,13 @@ class RingAttention(torch.autograd.Function):
             # all new KVs from the previous inner ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
 
+                    if i == 0:
+                        _kv_comm(i)
+
                     if ring_num_idx > inter_ring_rank:
                         kv_block = kv_buffers[i % 2]
                         (
@@ -778,6 +782,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
@@ -792,6 +798,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )