mirror of https://github.com/hpcaitech/ColossalAI
[feat] split communication and calculation; fix pop empty send_bwd_buffer error;
parent
1d75045c37
commit
5e09c8b4e1
|
@ -176,7 +176,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
||||
#################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.recv_forward_buffer[model_chunk_id].append(None)
|
||||
return None, []
|
||||
|
||||
################
|
||||
|
@ -186,24 +185,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
|
||||
# metadata_recv=self.tensor_metadata_recv
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
else:
|
||||
################
|
||||
# chunk = 1 & is_last_stage
|
||||
# get y from local_send_forward_buffer as input
|
||||
# do nothing; cause u get y from local_send_forward_buffer in schedule f
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
input_tensor = self.local_send_forward_buffer.pop(0)
|
||||
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
return input_tensor, []
|
||||
return None, []
|
||||
|
||||
################
|
||||
# chunk = 1 & not is_last_stage
|
||||
|
@ -212,10 +203,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(next_rank)
|
||||
|
||||
# metadata_recv=self.tensor_metadata_recv
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
|
@ -236,14 +223,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# bwd chunk0 is right V;
|
||||
################
|
||||
# chunk = 0 & is_last_stage
|
||||
# get dy from local recv_bwd_buffer
|
||||
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
return output_tensor_grad, []
|
||||
return None, []
|
||||
|
||||
################
|
||||
# chunk = 0 & not is_last_stage
|
||||
|
@ -252,9 +235,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
|
||||
# metadata_recv=self.grad_metadata_recv
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
|
@ -265,20 +245,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# do nothing; get loss from local
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.recv_backward_buffer[model_chunk_id].append(None)
|
||||
return None, []
|
||||
|
||||
################
|
||||
# chunk = 1 & not is_first_stage
|
||||
# self.comm.recv_backward recv bwd from prev stage;
|
||||
# chunk = 1 & not first stage
|
||||
# recv_backward recv bwd from prev stage;
|
||||
################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
|
||||
# metadata_recv=self.grad_metadata_recv
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
|
@ -296,14 +271,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
"""
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
if model_chunk_id == 0:
|
||||
################
|
||||
# chunk = 0 && is_last_stage
|
||||
# hold y on local_send_forward_buffer
|
||||
# do nothing; hold y on local_send_forward_buffer
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_forward_buffer.append(output_tensor)
|
||||
return []
|
||||
|
||||
################
|
||||
|
@ -312,15 +285,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
|
||||
# send_metadata=self.send_tensor_metadata
|
||||
# self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
else:
|
||||
################
|
||||
# chunk = 1 && is_first_stage
|
||||
# do nothing; cause you are the last chunk on last stage;
|
||||
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return []
|
||||
|
@ -331,9 +303,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(output_tensor, prev_rank)
|
||||
# send_metadata=self.send_tensor_metadata
|
||||
# self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||
|
@ -355,7 +326,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
################
|
||||
# chunk = 0 && is_first_stage
|
||||
# do nothing; cause u are the first chunk in first stage; bwd end
|
||||
# send input_tensor_grad to local buffer;
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return []
|
||||
|
@ -365,21 +335,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# Send dx to PREV stage;
|
||||
################
|
||||
else:
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
|
||||
# send_metadata=self.send_grad_metadata
|
||||
return send_handles
|
||||
|
||||
# bwd chunk1 is left V;
|
||||
else:
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}")
|
||||
################
|
||||
# chunk = 1 && is_last_stage
|
||||
# hold dy to local_send_bwd_buffer;
|
||||
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
self.local_send_backward_buffer.append(input_tensor_grad)
|
||||
return []
|
||||
|
||||
################
|
||||
|
@ -387,14 +355,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# Send dx to NEXT stage;
|
||||
################
|
||||
else:
|
||||
print(
|
||||
f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
|
||||
)
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
# print(f"send bwd input_tensor_grad {input_tensor_grad}")
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
|
||||
# send_metadata=self.send_grad_metadata
|
||||
return send_handles
|
||||
|
||||
def forward_step(
|
||||
|
@ -519,20 +482,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
outputs: Optional[List[Any]] = None,
|
||||
):
|
||||
# Step1: recv fwd
|
||||
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# # first layer
|
||||
# input_obj = input_obj
|
||||
# else:
|
||||
# # other layer
|
||||
# input_obj, wait_handles = self.recv_forward(model_chunk_id)
|
||||
# # print(f"recv input_obj {input_obj}")
|
||||
# _wait_p2p(wait_handles)
|
||||
if model_chunk_id == 0:
|
||||
# is first stage; get input from func param
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = input_obj
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = input_obj
|
||||
self.recv_forward_buffer[model_chunk_id].pop(0) # pop none
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
# is last stage; recv from local
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
input_obj = self.local_send_forward_buffer.pop(0)
|
||||
# not last stage; recv from next
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
# Step2: fwd step
|
||||
output_obj = self.forward_step(
|
||||
|
@ -555,8 +518,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
|
||||
# Step3: send fwd
|
||||
# add output to send_fwd_buffer
|
||||
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
||||
# send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
|
||||
if model_chunk_id == 0:
|
||||
# is last stage; send to local_send_forward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_forward_buffer.append(output_obj)
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
||||
else:
|
||||
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.local_send_backward_buffer.append(output_obj)
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
||||
|
||||
def schedule_b(
|
||||
self,
|
||||
|
@ -569,14 +542,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# output_obj_grad: Optional[dict],
|
||||
):
|
||||
# Step1: recv bwd
|
||||
# # not first stage and chunk 1
|
||||
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# output_tensor_grad, recv_bwd_handles = None, []
|
||||
# # print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
# else:
|
||||
# output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
|
||||
# # print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
if model_chunk_id == 0:
|
||||
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
# chunk 0 not last stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
else:
|
||||
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
|
||||
|
||||
|
@ -593,11 +572,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
||||
|
||||
# _wait_p2p(recv_bwd_handles)
|
||||
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
|
||||
# Step2: bwd step
|
||||
|
||||
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")
|
||||
|
||||
input_object_grad = self.backward_b_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
|
@ -609,8 +584,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
|
||||
|
||||
# Step3: send bwd
|
||||
# send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
|
||||
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
||||
if model_chunk_id == 0:
|
||||
# do nothing; end of bwd;
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
pass
|
||||
# save input_object_grad to send_backward_buffer
|
||||
else:
|
||||
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
||||
else:
|
||||
# send to local_send_backward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_backward_buffer.append(input_object_grad)
|
||||
# send to next
|
||||
else:
|
||||
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
||||
|
||||
def schedule_w(
|
||||
self,
|
||||
|
@ -644,9 +631,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
):
|
||||
it = self.it
|
||||
# while we still have schedules_node in self.schedules
|
||||
# print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
|
||||
while it < len(self.schedules):
|
||||
scheduled_node = self.schedules[it]
|
||||
print(f"it {it}; scheduled_node {scheduled_node};")
|
||||
print(
|
||||
f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
|
||||
)
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
if scheduled_node.type == "RECV_FORWARD":
|
||||
|
|
|
@ -486,7 +486,7 @@ def test_run_fwd_bwd_base(
|
|||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0),
|
||||
ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
|
||||
ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
|
||||
],
|
||||
# stage 2
|
||||
[
|
||||
|
@ -547,7 +547,7 @@ def test_run_fwd_bwd_base(
|
|||
# init model and input
|
||||
num_layers = 8
|
||||
in_dim = out_dim = 8
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
# print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
|
||||
|
||||
|
@ -578,9 +578,9 @@ def test_run_fwd_bwd_base(
|
|||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 3 or idx == 4:
|
||||
local_chunk.append(sub_model)
|
||||
print(
|
||||
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
# print(
|
||||
# f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
# )
|
||||
|
||||
torch.cuda.synchronize()
|
||||
scheduler.run_forward_backward(
|
||||
|
|
Loading…
Reference in New Issue