[feat] split communication and calculation; fix pop empty send_bwd_buffer error;

pull/6034/head
duanjunwen 2024-08-27 06:29:13 +00:00
parent 1d75045c37
commit 5e09c8b4e1
2 changed files with 75 additions and 85 deletions
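The shape of the change, in one sentence: the communication methods (recv_forward / recv_backward / send_forward / send_backward) only move tensors between P2P calls and per-chunk FIFO buffers, while the compute steps (schedule_f / schedule_b) only pop from and append to those buffers, using the local_send_* buffers whenever producer and consumer sit on the same rank. Below is a minimal runnable sketch of that handshake; the buffer names mirror the scheduler's, but CommStub, TinyScheduler and the three-line driver are illustrative stand-ins, not the real ZeroBubbleVPipeScheduler.

# --- Illustrative sketch (not part of this commit) -------------------------
# Buffer names mirror ZeroBubbleVPipeScheduler; CommStub and the driver at
# the bottom are invented stand-ins for the real P2P layer and schedule.
class CommStub:
    def recv_forward(self, prev_rank):
        return f"activation_from_rank_{prev_rank}", []   # (tensor, wait_handles)

    def send_forward(self, output_object, next_rank):
        return []                                        # wait_handles


class TinyScheduler:
    def __init__(self, num_chunks=2):
        self.comm = CommStub()
        # communication side fills these; calculation side only pops them
        self.recv_forward_buffer = [[] for _ in range(num_chunks)]
        self.send_forward_buffer = [[] for _ in range(num_chunks)]
        # producer and consumer on the same rank skip P2P entirely
        self.local_send_forward_buffer = []

    # communication: touches P2P and buffers, never runs the model
    def recv_forward(self, model_chunk_id, prev_rank):
        input_tensor, wait_handles = self.comm.recv_forward(prev_rank)
        self.recv_forward_buffer[model_chunk_id].append(input_tensor)
        return wait_handles

    def send_forward(self, model_chunk_id, next_rank):
        output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
        return self.comm.send_forward(output_tensor, next_rank)

    # calculation: touches buffers only, never calls P2P
    def schedule_f(self, model_chunk_id, is_last_stage=False):
        input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
        output_obj = f"f({input_obj})"                   # stand-in for forward_step
        if is_last_stage:
            self.local_send_forward_buffer.append(output_obj)   # hand off locally
        else:
            self.send_forward_buffer[model_chunk_id].append(output_obj)
        return output_obj


sched = TinyScheduler()
sched.recv_forward(model_chunk_id=0, prev_rank=0)    # RECV_FORWARD node
sched.schedule_f(model_chunk_id=0)                   # F node
sched.send_forward(model_chunk_id=0, next_rank=2)    # SEND_FORWARD node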

View File

@@ -176,7 +176,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.recv_forward_buffer[model_chunk_id].append(None)
return None, []
################
@@ -186,24 +185,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
else:
################
# chunk = 1 & is_last_stage
# get y from local_send_forward_buffer as input
# do nothing; cause u get y from local_send_forward_buffer in schedule f
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
input_tensor = self.local_send_forward_buffer.pop(0)
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, []
return None, []
################
# chunk = 1 & not is_last_stage
@@ -212,10 +203,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor, wait_handles = self.comm.recv_forward(next_rank)
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
@@ -236,14 +223,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# bwd chunk0 is right V;
################
# chunk = 0 & is_last_stage
# get dy from local recv_bwd_buffer
# do nothing; Already get dy from local_send_backward_buffer in schedule b
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, []
return None, []
################
# chunk = 0 & not is_last_stage
@@ -252,9 +235,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
# metadata_recv=self.grad_metadata_recv
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
@@ -265,20 +245,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; get loss from local
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.recv_backward_buffer[model_chunk_id].append(None)
return None, []
################
# chunk = 1 & not is_first_stage
# self.comm.recv_backward recv bwd from prev stage;
# chunk = 1 & not first stage
# recv_backward recv bwd from prev stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
# metadata_recv=self.grad_metadata_recv
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
@@ -296,14 +271,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
if model_chunk_id == 0:
################
# chunk = 0 && is_last_stage
# hold y on local_send_forward_buffer
# do nothing; hold y on local_send_forward_buffer
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(output_tensor)
return []
################
@@ -312,15 +285,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
# send_metadata=self.send_tensor_metadata
# self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
else:
################
# chunk = 1 && is_first_stage
# do nothing; cause you are the last chunk on last stage;
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
@@ -331,9 +303,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_tensor, prev_rank)
# send_metadata=self.send_tensor_metadata
# self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
@@ -355,7 +326,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################
# chunk = 0 && is_first_stage
# do nothing; cause u are the first chunk in first stage; bwd end
# send input_tensor_grad to local buffer;
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
@@ -365,21 +335,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to PREV stage;
################
else:
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
prev_rank = self.stage_manager.get_prev_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
# send_metadata=self.send_grad_metadata
return send_handles
# bwd chunk1 is left V;
else:
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}")
################
# chunk = 1 && is_last_stage
# hold dy to local_send_bwd_buffer;
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
self.local_send_backward_buffer.append(input_tensor_grad)
return []
################
@@ -387,14 +355,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to NEXT stage;
################
else:
print(
f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
)
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
next_rank = self.stage_manager.get_next_rank()
# print(f"send bwd input_tensor_grad {input_tensor_grad}")
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
# send_metadata=self.send_grad_metadata
return send_handles
def forward_step(
@@ -519,20 +482,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs: Optional[List[Any]] = None,
):
# Step1: recv fwd
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
# # first layer
# input_obj = input_obj
# else:
# # other layer
# input_obj, wait_handles = self.recv_forward(model_chunk_id)
# # print(f"recv input_obj {input_obj}")
# _wait_p2p(wait_handles)
if model_chunk_id == 0:
# is first stage; get input from func param
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = input_obj
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = input_obj
self.recv_forward_buffer[model_chunk_id].pop(0) # pop none
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# is last stage; recv from local
if self.stage_manager.is_last_stage(ignore_chunk=True):
input_obj = self.local_send_forward_buffer.pop(0)
# not last stage; recv from next
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# Step2: fwd step
output_obj = self.forward_step(
@@ -555,8 +518,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Step3: send fwd
# add output to send_fwd_buffer
self.send_forward_buffer[model_chunk_id].append(output_obj)
# send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
if model_chunk_id == 0:
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(output_obj)
else:
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.local_send_backward_buffer.append(output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(output_obj)
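The Step 3 branches above form a small routing table for the forward output. A hypothetical helper (illustration only, not part of the commit) that captures the same four cases:

# --- Illustrative sketch (not part of this commit) -------------------------
def route_forward_output(scheduler, model_chunk_id, output_obj, is_first_stage, is_last_stage):
    # Same four cases as Step 3 of schedule_f above; the two booleans stand in
    # for stage_manager.is_first_stage / is_last_stage with ignore_chunk=True.
    if model_chunk_id == 0:
        if is_last_stage:
            # chunk 0 ends on the last stage; chunk 1 picks y up locally
            scheduler.local_send_forward_buffer.append(output_obj)
        else:
            scheduler.send_forward_buffer[model_chunk_id].append(output_obj)
    else:
        if is_first_stage:
            # end of the V: fwd finishes where bwd starts, so stash the loss
            # for schedule_b instead of sending it anywhere
            scheduler.local_send_backward_buffer.append(output_obj)
        else:
            scheduler.send_forward_buffer[model_chunk_id].append(output_obj)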
def schedule_b(
self,
@@ -569,14 +542,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# output_obj_grad: Optional[dict],
):
# Step1: recv bwd
# # not first stage and chunk 1
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# output_tensor_grad, recv_bwd_handles = None, []
# # print(f"recv output_tensor_grad {output_tensor_grad}")
# else:
# output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
# # print(f"recv output_tensor_grad {output_tensor_grad}")
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# chunk 0 not last stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
else:
# chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
@@ -593,11 +572,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# _wait_p2p(recv_bwd_handles)
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
# Step2: bwd step
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
@@ -609,8 +584,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
# Step3: send bwd
# send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
if model_chunk_id == 0:
# do nothing; end of bwd;
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
# save input_object_grad to send_backward_buffer
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
else:
# send to local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_backward_buffer.append(input_object_grad)
# send to next
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
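The exact trigger of the "pop empty send_bwd_buffer" error is not visible in this diff, but after this routing the send_backward_buffer is only appended to when a SEND_BACKWARD node on the same stage will pop it: local hand-offs go straight to local_send_backward_buffer here in schedule_b (matching the "do nothing" branches of send_backward above), and chunk 0 on the first stage appends nothing at the end of backward. A hypothetical mirror of the forward-routing helper, again illustration only:

# --- Illustrative sketch (not part of this commit) -------------------------
def route_backward_output(scheduler, model_chunk_id, input_object_grad, is_first_stage, is_last_stage):
    # Mirror of Step 3 of schedule_b above: a gradient lands in
    # send_backward_buffer only if a SEND_BACKWARD node on this stage pops it.
    if model_chunk_id == 0:
        if is_first_stage:
            return  # end of backward for the whole pipeline; nothing to send
        scheduler.send_backward_buffer[model_chunk_id].append(input_object_grad)
    else:
        if is_last_stage:
            # chunk 1 hands dx to chunk 0 on the same rank, no P2P involved
            scheduler.local_send_backward_buffer.append(input_object_grad)
        else:
            scheduler.send_backward_buffer[model_chunk_id].append(input_object_grad)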
def schedule_w(
self,
@@ -644,9 +631,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
):
it = self.it
# while we still have schedules_node in self.schedules
# print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
while it < len(self.schedules):
scheduled_node = self.schedules[it]
print(f"it {it}; scheduled_node {scheduled_node};")
print(
f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
)
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
if scheduled_node.type == "RECV_FORWARD":
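The loop above is where communication and calculation meet: node types in AUTO_SCHEDULE_COMMUNICATION_TYPES are dispatched to the recv_*/send_* methods, everything else to schedule_f / schedule_b / schedule_w. A compressed, hypothetical version of that dispatch, assuming a scheduler whose recv_*/send_* and schedule_* methods take just the chunk id (the real methods also take ranks and model arguments, and the real loop also tracks minibatches, wait handles and the it cursor):

# --- Illustrative sketch (not part of this commit) -------------------------
from collections import namedtuple

Node = namedtuple("Node", ["type", "chunk"])
P2P_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}

def run_schedule(scheduler, schedules):
    for node in schedules:
        if node.type in P2P_TYPES:
            # e.g. "RECV_FORWARD" -> scheduler.recv_forward(node.chunk)
            getattr(scheduler, node.type.lower())(node.chunk)
        elif node.type == "F":
            scheduler.schedule_f(node.chunk)
        elif node.type == "B":
            scheduler.schedule_b(node.chunk)
        elif node.type == "W":
            scheduler.schedule_w(node.chunk)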

View File

@@ -486,7 +486,7 @@ def test_run_fwd_bwd_base(
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
],
# stage 2
[
@@ -547,7 +547,7 @@ def test_run_fwd_bwd_base(
# init model and input
num_layers = 8
in_dim = out_dim = 8
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
# print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
@@ -578,9 +578,9 @@ def test_run_fwd_bwd_base(
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
local_chunk.append(sub_model)
print(
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
)
# print(
# f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
# )
torch.cuda.synchronize()
scheduler.run_forward_backward(