overlap kv comm with output rescale (#6017)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Edenzzzz 2024-08-19 14:08:17 +08:00 committed by GitHub
parent 26493b97d3
commit f1c3266a94
1 changed file with 19 additions and 11 deletions


@@ -690,6 +690,13 @@ class RingAttention(torch.autograd.Function):
             )
             return out, softmax_lse, rng_state
 
+        def _kv_comm(i):
+            # Avoid overwriting attn input when it shares mem with buffer
+            if not RingAttention.ATTN_DONE.query():
+                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+            if i < local_sp_size - 1:
+                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
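The new _kv_comm(i) helper factors the double-buffered KV exchange out of the loop body so it can be launched either at the top of the first iteration or right after the attention kernel of the previous one. For reference, below is a minimal standalone sketch of the same pattern written with raw torch.distributed P2P ops; the names kv_comm_step and n_steps are made up here, and ColossalAI's RingComm.send_recv handles this bookkeeping itself, so this is an illustration of the technique rather than the library's API.

import torch
import torch.distributed as dist


def kv_comm_step(i, kv_buffers, n_steps, attn_done: torch.cuda.Event):
    # Standalone illustration of what the _kv_comm helper does: a double-buffered
    # ring exchange of the KV block, guarded so the recv buffer is not
    # overwritten while the attention kernel may still be reading it.
    if not attn_done.query():
        # Attention has not finished with the buffer that would receive the
        # next block (it can alias the attention input), so use a fresh one.
        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
    if i >= n_steps - 1:
        return []  # last step: no further block to exchange
    rank, world = dist.get_rank(), dist.get_world_size()
    ops = [
        dist.P2POp(dist.isend, kv_buffers[i % 2], (rank + 1) % world),
        dist.P2POp(dist.irecv, kv_buffers[(i + 1) % 2], (rank - 1) % world),
    ]
    # Async: the caller must wait on these handles before using the recv buffer.
    return dist.batch_isend_irecv(ops)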
@@ -698,12 +705,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
-                    # Avoid overwriting attn input when it shares mem with buffer
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
+                    if i == 0:
+                        _kv_comm(i)
 
                     if i == 0:
                         # Compute with local KV; no mask
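The NOTE above is about CUDA stream semantics: for NCCL, Work.wait() makes the *current* stream wait for the communication, so it only orders kernels that are launched on that same stream. A small illustration of why the wait has to be issued inside the with torch.cuda.stream(sp_streams[i % 2]) block; the helper name attend_after_recv and the attn_fn callback are hypothetical, used only for this sketch.

import torch


def attend_after_recv(req, q, kv, attn_fn, comp_stream: torch.cuda.Stream):
    # For NCCL, req.wait() makes the *current* CUDA stream wait for the
    # communication. Calling it outside the stream context would only order
    # the default stream, not comp_stream, so an attention kernel launched
    # on comp_stream could read the recv buffer before the transfer finishes.
    with torch.cuda.stream(comp_stream):
        for r in (req if isinstance(req, list) else [req]):
            r.wait()  # comp_stream now waits for the transfer
        return attn_fn(q, kv[0], kv[1])  # enqueued after the recv on comp_stream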
@@ -734,6 +737,9 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i],
                         ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                     RingAttention.ATTN_DONE.record()
+                    # Pipeline the next KV comm with output correction instead of the next flash attn
+                    # to minimize idle time when comm takes longer than attn.
+                    _kv_comm(i + 1)
 
                     block_softmax_lse[i % 2] = (
                         block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
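This is the "output rescale" from the commit title: each block's partial output and log-sum-exp are merged into the running result with the usual online-softmax correction, and the commit posts the next KV send/recv just before that correction so the transfer overlaps with it rather than with the next flash-attention call. Below is a sketch of the standard merge, assuming lse has already been brought to shape (total_q, H, 1) in float32 as the transpose/unsqueeze above does; ColossalAI's actual helper may differ in naming and shape handling.

import torch
import torch.nn.functional as F


def rescale_out_lse(out, block_out, lse, block_lse):
    # Online-softmax merge of one attention block into the running output.
    # Assumed shapes: out, block_out: (total_q, H, D); lse, block_lse: (total_q, H, 1), float32.
    # new_lse = log(exp(lse) + exp(block_lse)), written via logsigmoid for stability
    new_lse = lse - F.logsigmoid(lse - block_lse)
    # out = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
    out = out - torch.sigmoid(block_lse - lse) * (out - block_out)
    return out, new_lse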
@@ -761,15 +767,13 @@ class RingAttention(torch.autograd.Function):
             # all new KVs from the previous inner ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
+                    if i == 0:
+                        _kv_comm(i)
+
                     if ring_num_idx > inter_ring_rank:
                         kv_block = kv_buffers[i % 2]
                         (
@@ -778,6 +782,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+                        _kv_comm(i + 1)
+
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
@@ -792,6 +798,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+                        _kv_comm(i + 1)
+
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
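Putting the pieces together, the per-iteration schedule after this commit is roughly: wait for the current KV block, run flash attention on it, record ATTN_DONE, immediately post the next block's send/recv, and only then do the output correction, so the transfer hides behind the rescale math. A condensed, hypothetical sketch of that schedule under stated assumptions (causal-mask branches, the inter-node ring, and the buffer-reuse guard inside kv_comm are omitted; flash_attn, rescale, and kv_comm are stand-ins passed in by the caller):

import torch


def local_ring_schedule(q, kv_local, n_steps, sp_streams, kv_comm, flash_attn, rescale, attn_done):
    # Condensed view of the pipelining introduced by this commit.
    kv_buffers = [kv_local, torch.empty_like(kv_local)]
    reqs = [None, None]
    out = lse = None
    for i in range(n_steps):
        with torch.cuda.stream(sp_streams[i % 2]):
            if i > 0:
                for r in reqs[(i + 1) % 2]:
                    r.wait()  # the KV block for step i has arrived
            if i == 0:
                reqs[i % 2] = kv_comm(i, kv_buffers)  # prefetch under block 0's compute
            block_out, block_lse = flash_attn(q, kv_buffers[i % 2])
            attn_done.record()  # recv-side buffer may be reused from here on
            if i + 1 < n_steps:
                # Post the next transfer now so it overlaps the rescale below,
                # not the next flash-attention call.
                reqs[(i + 1) % 2] = kv_comm(i + 1, kv_buffers)
            if out is None:
                out, lse = block_out, block_lse
            else:
                out, lse = rescale(out, block_out, lse, block_lse)
    for s in sp_streams:
        torch.cuda.current_stream().wait_stream(s)
    return out, lse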