diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 6dab17ec0..5d1a30d8a 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -690,6 +690,13 @@ class RingAttention(torch.autograd.Function):
             )
             return out, softmax_lse, rng_state
 
+        def _kv_comm(i):
+            # Avoid overwriting attn input when it shares mem with buffer
+            if not RingAttention.ATTN_DONE.query():
+                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+            if i < local_sp_size - 1:
+                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
@@ -698,12 +705,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
-
-                    # Avoid overwriting attn input when it shares mem with buffer
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+                    if i == 0:
+                        _kv_comm(i)
 
                     if i == 0:
                         # Compute with local KV; no mask
@@ -734,6 +737,9 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i],
                         ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                     RingAttention.ATTN_DONE.record()
+                    # Pipeline the next KV comm with output correction instead of the next flash attn
+                    # to minimize idle time when comm takes longer than attn.
+                    _kv_comm(i + 1)
 
                     block_softmax_lse[i % 2] = (
                         block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
@@ -761,15 +767,13 @@ class RingAttention(torch.autograd.Function):
             # all new KVs from the previous inner ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
 
+                    if i == 0:
+                        _kv_comm(i)
+
                     if ring_num_idx > inter_ring_rank:
                         kv_block = kv_buffers[i % 2]
                         (
@@ -778,6 +782,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
@@ -792,6 +798,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
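
For reference, below is a minimal sketch of the double-buffered ring KV exchange that this patch re-schedules: post the next send/recv asynchronously, run attention on the block currently held, and only wait on the transfer when the buffers are about to swap. The names (`ring_attn_sketch`, `attn_block`, `merge_out`) are placeholders rather than ColossalAI APIs, and the sketch omits the two `sp_streams` CUDA streams and the `ATTN_DONE` event that the real code uses to launch `_kv_comm(i + 1)` right after the flash-attn kernel, so the transfer also overlaps with output correction.

```python
# Hypothetical sketch, not the ColossalAI implementation; assumes launch via
# `torchrun --nproc_per_node=<N>` with an initialized NCCL process group.
import torch
import torch.distributed as dist


def ring_attn_sketch(q, kv_local, attn_block, merge_out):
    """Rotate KV blocks around the ring, overlapping comm with compute."""
    rank, world = dist.get_rank(), dist.get_world_size()
    send_to, recv_from = (rank + 1) % world, (rank - 1) % world

    # Double buffer: compute on one block while the next one lands in the other.
    kv_buffers = [kv_local.contiguous(), torch.empty_like(kv_local)]
    out = lse = None

    for step in range(world):
        cur, nxt = step % 2, (step + 1) % 2
        reqs = []
        if step < world - 1:
            # Post the next send/recv before the attention kernel: the send only
            # reads `cur` and the recv only writes `nxt`, so the transfer can run
            # concurrently with the compute below.
            reqs = dist.batch_isend_irecv(
                [
                    dist.P2POp(dist.isend, kv_buffers[cur], send_to),
                    dist.P2POp(dist.irecv, kv_buffers[nxt], recv_from),
                ]
            )

        # Attention on the block currently held, then merge into the running
        # output -- the "output correction" step the patch overlaps with comm.
        block_out, block_lse = attn_block(q, kv_buffers[cur])
        out, lse = merge_out(out, lse, block_out, block_lse)

        # Make sure the next block has fully arrived before the buffers swap.
        for req in reqs:
            req.wait()

    return out, lse
```

Launching the comm right after the attention kernel, as the patch does with `_kv_comm(i + 1)`, keeps the link busy during the comparatively cheap LSE/output correction, which reduces idle time when communication takes longer than attention.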