mirror of https://github.com/hpcaitech/ColossalAI
overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>pull/6022/head
parent
26493b97d3
commit
f1c3266a94
|
@ -690,6 +690,13 @@ class RingAttention(torch.autograd.Function):
|
|||
)
|
||||
return out, softmax_lse, rng_state
|
||||
|
||||
def _kv_comm(i):
|
||||
# Avoid overwriting attn input when it shares mem with buffer
|
||||
if not RingAttention.ATTN_DONE.query():
|
||||
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
|
||||
if i < local_sp_size - 1:
|
||||
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||
|
||||
def _local_ring_forward():
|
||||
# (Hopefully) overlap output correction with next flash attn
|
||||
for i in range(local_sp_size):
|
||||
|
@ -698,12 +705,8 @@ class RingAttention(torch.autograd.Function):
|
|||
# NOTE: waiting outside the current stream will NOT correctly synchronize.
|
||||
if i > 0:
|
||||
local_kv_comms[(i + 1) % 2].wait()
|
||||
|
||||
# Avoid overwriting attn input when it shares mem with buffer
|
||||
if not RingAttention.ATTN_DONE.query():
|
||||
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
|
||||
if i < local_sp_size - 1:
|
||||
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||
if i == 0:
|
||||
_kv_comm(i)
|
||||
|
||||
if i == 0:
|
||||
# Compute with local KV; no mask
|
||||
|
@ -734,6 +737,9 @@ class RingAttention(torch.autograd.Function):
|
|||
rng_states[i],
|
||||
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
|
||||
RingAttention.ATTN_DONE.record()
|
||||
# Pipeline the next KV comm with output correction instead of the next flash attn
|
||||
# to minimize idle time when comm takes longer than attn.
|
||||
_kv_comm(i + 1)
|
||||
|
||||
block_softmax_lse[i % 2] = (
|
||||
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
||||
|
@ -761,15 +767,13 @@ class RingAttention(torch.autograd.Function):
|
|||
# all new KVs from the previous inner ring
|
||||
for i in range(local_sp_size):
|
||||
with torch.cuda.stream(sp_streams[i % 2]):
|
||||
if not RingAttention.ATTN_DONE.query():
|
||||
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
|
||||
if i < local_sp_size - 1:
|
||||
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||
|
||||
# Send & recv KV
|
||||
if i > 0:
|
||||
local_kv_comms[(i + 1) % 2].wait()
|
||||
|
||||
if i == 0:
|
||||
_kv_comm(i)
|
||||
|
||||
if ring_num_idx > inter_ring_rank:
|
||||
kv_block = kv_buffers[i % 2]
|
||||
(
|
||||
|
@ -778,6 +782,8 @@ class RingAttention(torch.autograd.Function):
|
|||
rng_states[i + local_sp_size * ring_num_idx],
|
||||
) = _forward(q1, kv_block[0], kv_block[1], causal=False)
|
||||
RingAttention.ATTN_DONE.record()
|
||||
|
||||
_kv_comm(i + 1)
|
||||
block_softmax_lse[i % 2] = (
|
||||
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
||||
)
|
||||
|
@ -792,6 +798,8 @@ class RingAttention(torch.autograd.Function):
|
|||
rng_states[i + local_sp_size * ring_num_idx],
|
||||
) = _forward(q, kv_block[0], kv_block[1], causal=False)
|
||||
RingAttention.ATTN_DONE.record()
|
||||
|
||||
_kv_comm(i + 1)
|
||||
block_softmax_lse[i % 2] = (
|
||||
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue