diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index f2d33f7b5..da5320cf3 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -176,7 +176,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; cause u are chunk 0 in first rank, u have no prev rank; ################# if self.stage_manager.is_first_stage(ignore_chunk=True): - self.recv_forward_buffer[model_chunk_id].append(None) return None, [] ################ @@ -186,24 +185,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: prev_rank = self.stage_manager.get_prev_rank() input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank) - # metadata_recv=self.tensor_metadata_recv - # if self.enable_metadata_cache and self.tensor_metadata_recv is None: - # self.tensor_metadata_recv = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) return input_tensor, wait_handles else: ################ # chunk = 1 & is_last_stage - # get y from local_send_forward_buffer as input + # do nothing; cause u get y from local_send_forward_buffer in schedule f ################ if self.stage_manager.is_last_stage(ignore_chunk=True): - input_tensor = self.local_send_forward_buffer.pop(0) - - # if self.enable_metadata_cache and self.tensor_metadata_recv is None: - # self.tensor_metadata_recv = create_send_metadata(input_tensor) - self.recv_forward_buffer[model_chunk_id].append(input_tensor) - return input_tensor, [] + return None, [] ################ # chunk = 1 & not is_last_stage @@ -212,10 +203,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() input_tensor, wait_handles = self.comm.recv_forward(next_rank) - - # metadata_recv=self.tensor_metadata_recv - # if self.enable_metadata_cache and self.tensor_metadata_recv is None: - # self.tensor_metadata_recv = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) return input_tensor, wait_handles @@ -236,14 +223,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # bwd chunk0 is right V; ################ # chunk = 0 & is_last_stage - # get dy from local recv_bwd_buffer + # do nothing; Already get dy from local_send_backward_buffer in schedule b ################ if self.stage_manager.is_last_stage(ignore_chunk=True): - output_tensor_grad = self.local_send_backward_buffer.pop(0) - # if self.enable_metadata_cache and self.grad_metadata_recv is None: - # self.grad_metadata_recv = create_send_metadata(output_tensor_grad) - self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) - return output_tensor_grad, [] + return None, [] ################ # chunk = 0 & not is_last_stage @@ -252,9 +235,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank) - # metadata_recv=self.grad_metadata_recv - # if self.enable_metadata_cache and self.grad_metadata_recv is None: - # self.grad_metadata_recv = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) return output_tensor_grad, wait_handles @@ -265,20 +245,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; get loss from local ################ if self.stage_manager.is_first_stage(ignore_chunk=True): - self.recv_backward_buffer[model_chunk_id].append(None) return None, [] ################ - # chunk = 1 & not is_first_stage - # self.comm.recv_backward recv bwd from prev stage; + # chunk = 1 & not first stage + # recv_backward recv bwd from prev stage; ################ else: prev_rank = self.stage_manager.get_prev_rank() output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank) - # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}") - # metadata_recv=self.grad_metadata_recv - # if self.enable_metadata_cache and self.grad_metadata_recv is None: - # self.grad_metadata_recv = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) return output_tensor_grad, wait_handles @@ -296,14 +271,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): - output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) if model_chunk_id == 0: ################ # chunk = 0 && is_last_stage - # hold y on local_send_forward_buffer + # do nothing; hold y on local_send_forward_buffer ################ if self.stage_manager.is_last_stage(ignore_chunk=True): - self.local_send_forward_buffer.append(output_tensor) return [] ################ @@ -312,15 +285,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ################ else: next_rank = self.stage_manager.get_next_rank() + output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) - # send_metadata=self.send_tensor_metadata - # self.send_tensor_metadata = not self.enable_metadata_cache return send_handles else: ################ # chunk = 1 && is_first_stage - # do nothing; cause you are the last chunk on last stage; + # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return [] @@ -331,9 +303,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ################ else: prev_rank = self.stage_manager.get_prev_rank() + output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_tensor, prev_rank) - # send_metadata=self.send_tensor_metadata - # self.send_tensor_metadata = not self.enable_metadata_cache return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: @@ -355,7 +326,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ################ # chunk = 0 && is_first_stage # do nothing; cause u are the first chunk in first stage; bwd end - # send input_tensor_grad to local buffer; ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return [] @@ -365,21 +335,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Send dx to PREV stage; ################ else: - input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) prev_rank = self.stage_manager.get_prev_rank() + input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward(input_tensor_grad, prev_rank) - # send_metadata=self.send_grad_metadata return send_handles # bwd chunk1 is left V; else: + # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}") ################ # chunk = 1 && is_last_stage - # hold dy to local_send_bwd_buffer; + # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; ################ if self.stage_manager.is_last_stage(ignore_chunk=True): - input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) - self.local_send_backward_buffer.append(input_tensor_grad) return [] ################ @@ -387,14 +355,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Send dx to NEXT stage; ################ else: - print( - f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}" - ) - input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) next_rank = self.stage_manager.get_next_rank() - # print(f"send bwd input_tensor_grad {input_tensor_grad}") + input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward(input_tensor_grad, next_rank) - # send_metadata=self.send_grad_metadata return send_handles def forward_step( @@ -519,20 +482,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): outputs: Optional[List[Any]] = None, ): # Step1: recv fwd - # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - # # first layer - # input_obj = input_obj - # else: - # # other layer - # input_obj, wait_handles = self.recv_forward(model_chunk_id) - # # print(f"recv input_obj {input_obj}") - # _wait_p2p(wait_handles) + if model_chunk_id == 0: + # is first stage; get input from func param + if self.stage_manager.is_first_stage(ignore_chunk=True): + input_obj = input_obj + else: + input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) - if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = input_obj - self.recv_forward_buffer[model_chunk_id].pop(0) # pop none else: - input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + # is last stage; recv from local + if self.stage_manager.is_last_stage(ignore_chunk=True): + input_obj = self.local_send_forward_buffer.pop(0) + # not last stage; recv from next + else: + input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) # Step2: fwd step output_obj = self.forward_step( @@ -555,8 +518,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Step3: send fwd # add output to send_fwd_buffer - self.send_forward_buffer[model_chunk_id].append(output_obj) - # send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj) + if model_chunk_id == 0: + # is last stage; send to local_send_forward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_forward_buffer.append(output_obj) + else: + self.send_forward_buffer[model_chunk_id].append(output_obj) + else: + # is first stage; end of fwd; append LOSS to local_send_backward_buffer + if self.stage_manager.is_first_stage(ignore_chunk=True): + self.local_send_backward_buffer.append(output_obj) + else: + self.send_forward_buffer[model_chunk_id].append(output_obj) def schedule_b( self, @@ -569,14 +542,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # output_obj_grad: Optional[dict], ): # Step1: recv bwd - # # not first stage and chunk 1 - # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # output_tensor_grad, recv_bwd_handles = None, [] - # # print(f"recv output_tensor_grad {output_tensor_grad}") - # else: - # output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id) - # # print(f"recv output_tensor_grad {output_tensor_grad}") - output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + if model_chunk_id == 0: + # chunk0 is last stage; recv output_grad from local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + output_tensor_grad = self.local_send_backward_buffer.pop(0) + # chunk 0 not last stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + else: + # chunk1, is first stage; recv LOSS from local send bwd buffer + if self.stage_manager.is_first_stage(ignore_chunk=True): + output_tensor_grad = self.local_send_backward_buffer.pop(0) + # chunk1, not first stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n") @@ -593,11 +572,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) # _wait_p2p(recv_bwd_handles) - # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}") # Step2: bwd step - - # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}") - input_object_grad = self.backward_b_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, @@ -609,8 +584,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}") # Step3: send bwd - # send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad) - self.send_backward_buffer[model_chunk_id].append(input_object_grad) + if model_chunk_id == 0: + # do nothing; end of bwd; + if self.stage_manager.is_first_stage(ignore_chunk=True): + pass + # save input_object_grad to send_backward_buffer + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) + else: + # send to local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_backward_buffer.append(input_object_grad) + # send to next + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) def schedule_w( self, @@ -644,9 +631,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ): it = self.it # while we still have schedules_node in self.schedules + # print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n") while it < len(self.schedules): scheduled_node = self.schedules[it] - print(f"it {it}; scheduled_node {scheduled_node};") + print( + f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};" + ) if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication if scheduled_node.type == "RECV_FORWARD": diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index a8502c2af..fe8dd6c36 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -486,7 +486,7 @@ def test_run_fwd_bwd_base( ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0), ScheduledNode(type="B", chunk=0, stage=1, minibatch=0), ScheduledNode(type="W", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), ], # stage 2 [ @@ -547,7 +547,7 @@ def test_run_fwd_bwd_base( # init model and input num_layers = 8 in_dim = out_dim = 8 - print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + # print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank) @@ -578,9 +578,9 @@ def test_run_fwd_bwd_base( for idx, sub_model in enumerate(model.layers): if idx == 3 or idx == 4: local_chunk.append(sub_model) - print( - f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" - ) + # print( + # f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + # ) torch.cuda.synchronize() scheduler.run_forward_backward(