diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py
index 0d083c610..f1ea3f61e 100644
--- a/colossalai/pipeline/schedule/v_schedule.py
+++ b/colossalai/pipeline/schedule/v_schedule.py
@@ -12,8 +12,8 @@ class ScheduledNode:
     chunk: int
     stage: int
     minibatch: int
-    start_time: int
-    completion_time: int
+    # start_time: int
+    # completion_time: int
     rollback: bool = False
 
 
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 0fef29446..f2d33f7b5 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -176,6 +176,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # do nothing; cause u are chunk 0 in first rank, u have no prev rank;
                 #################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
+                    self.recv_forward_buffer[model_chunk_id].append(None)
                     return None, []
 
                 ################
@@ -188,6 +189,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     # metadata_recv=self.tensor_metadata_recv
                     # if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                     # self.tensor_metadata_recv = create_send_metadata(input_tensor)
+                    self.recv_forward_buffer[model_chunk_id].append(input_tensor)
                     return input_tensor, wait_handles
 
         else:
@@ -200,7 +202,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     # if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                     # self.tensor_metadata_recv = create_send_metadata(input_tensor)
-
+                    self.recv_forward_buffer[model_chunk_id].append(input_tensor)
                     return input_tensor, []
 
                 ################
@@ -214,7 +216,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     # metadata_recv=self.tensor_metadata_recv
                     # if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                     # self.tensor_metadata_recv = create_send_metadata(input_tensor)
-
+                    self.recv_forward_buffer[model_chunk_id].append(input_tensor)
                     return input_tensor, wait_handles
 
     def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
@@ -240,6 +242,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     output_tensor_grad = self.local_send_backward_buffer.pop(0)
                     # if self.enable_metadata_cache and self.grad_metadata_recv is None:
                     # self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+                    self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
                    return output_tensor_grad, []
 
                 ################
@@ -252,6 +255,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     # metadata_recv=self.grad_metadata_recv
                     # if self.enable_metadata_cache and self.grad_metadata_recv is None:
                     # self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+                    self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
                     return output_tensor_grad, wait_handles
 
         else:
@@ -261,6 +265,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # do nothing; get loss from local
                 ################
                 if self.stage_manager.is_first_stage(ignore_chunk=True):
+                    self.recv_backward_buffer[model_chunk_id].append(None)
                     return None, []
 
                 ################
@@ -268,16 +273,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # self.comm.recv_backward recv bwd from prev stage;
                 ################
                 else:
-
                     prev_rank = self.stage_manager.get_prev_rank()
                     output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
-
+                    # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
                     # metadata_recv=self.grad_metadata_recv
                     # if self.enable_metadata_cache and self.grad_metadata_recv is None:
                     # self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+                    self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
                     return output_tensor_grad, wait_handles
 
-    def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
+    def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
         """Sends the input tensor to the next stage in pipeline.
 
         For ZBV.
@@ -291,6 +296,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
 
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
             if model_chunk_id == 0:
                 ################
                 # chunk = 0 && is_last_stage
@@ -330,7 +336,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     # self.send_tensor_metadata = not self.enable_metadata_cache
                     return send_handles
 
-    def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
+    def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
         """Sends the gradient tensor to the previous stage in pipeline.
 
         For ZBV.
@@ -359,6 +365,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # Send dx to PREV stage;
                 ################
                 else:
+                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                     prev_rank = self.stage_manager.get_prev_rank()
                     send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
                     # send_metadata=self.send_grad_metadata
@@ -371,6 +378,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # hold dy to local_send_bwd_buffer;
                 ################
                 if self.stage_manager.is_last_stage(ignore_chunk=True):
+                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                     self.local_send_backward_buffer.append(input_tensor_grad)
                     return []
 
@@ -379,6 +387,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # Send dx to NEXT stage;
                 ################
                 else:
+                    print(
+                        f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
+                    )
+                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                     next_rank = self.stage_manager.get_next_rank()
                     # print(f"send bwd input_tensor_grad {input_tensor_grad}")
                     send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
@@ -413,6 +425,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Only attention_mask from micro_batch is used
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            # fwd calculate
             output_obj = model_chunk[model_chunk_id](input_obj)
             # last layer in model
             if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -463,6 +476,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # commom bwd step
         # print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
         # BUG:output_obj_grad is None
+        # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
         torch.autograd.backward(
             tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
         )
@@ -505,14 +519,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         outputs: Optional[List[Any]] = None,
     ):
         # Step1: recv fwd
+        # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+        #     # first layer
+        #     input_obj = input_obj
+        # else:
+        #     # other layer
+        #     input_obj, wait_handles = self.recv_forward(model_chunk_id)
+        #     # print(f"recv input_obj {input_obj}")
+        #     _wait_p2p(wait_handles)
+
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # first layer
             input_obj = input_obj
+            self.recv_forward_buffer[model_chunk_id].pop(0)  # pop none
         else:
-            # other layer
-            input_obj, wait_handles = self.recv_forward(model_chunk_id)
-            # print(f"recv input_obj {input_obj}")
-            _wait_p2p(wait_handles)
+            input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+
         # Step2: fwd step
         output_obj = self.forward_step(
@@ -522,6 +543,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
+        # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
 
         # add input and output object for backward b
@@ -532,7 +554,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.output_tensors_dw[model_chunk_id].append(output_obj)
 
         # Step3: send fwd
-        send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
+        # add output to send_fwd_buffer
+        self.send_forward_buffer[model_chunk_id].append(output_obj)
+        # send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
 
     def schedule_b(
         self,
@@ -545,17 +569,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # output_obj_grad: Optional[dict],
     ):
         # Step1: recv bwd
-        # not first stage and chunk 1
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            output_tensor_grad, recv_bwd_handles = None, []
-            # print(f"recv output_tensor_grad {output_tensor_grad}")
-        else:
-            output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
-            # print(f"recv output_tensor_grad {output_tensor_grad}")
+        # # not first stage and chunk 1
+        # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+        #     output_tensor_grad, recv_bwd_handles = None, []
+        #     # print(f"recv output_tensor_grad {output_tensor_grad}")
+        # else:
+        #     output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
+        #     # print(f"recv output_tensor_grad {output_tensor_grad}")
+        output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+
+        # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
 
         # get input and output object from buffer;
-        input_obj = self.input_tensors[model_chunk_id].pop()
-        output_obj = self.output_tensors[model_chunk_id].pop()
+        input_obj = self.input_tensors[model_chunk_id].pop(0)
+        output_obj = self.output_tensors[model_chunk_id].pop(0)
 
         # save output_tensor_grad for dw
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -565,9 +592,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # we save output_tensor_grad here
             self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
 
-        _wait_p2p(recv_bwd_handles)
+        # _wait_p2p(recv_bwd_handles)
         # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
 
         # Step2: bwd step
+
+        # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")
+
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
@@ -576,23 +606,23 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
         )
-        print(f"input_object_grad {input_object_grad}")
+        # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
 
         # Step3: send bwd
-        send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
+        # send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
+        self.send_backward_buffer[model_chunk_id].append(input_object_grad)
 
     def schedule_w(
         self,
         scheduled_node,
-        non_w_pending,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
     ):
         # get y & dy from buffer
-        output_obj = self.output_tensors_dw[model_chunk_id].pop()
-        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
+        output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
+        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
 
         self.backward_w_step(
             model_chunk=model_chunk,
@@ -605,6 +635,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
+        input_obj: Optional[dict],
         data_iter: Iterable,
         criterion: Callable[..., Any],
         optimizer: Optional[OptimizerWrapper] = None,
@@ -615,19 +646,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        # while we still have schedules_node in self.schedules
         while it < len(self.schedules):
             scheduled_node = self.schedules[it]
+            print(f"it {it}; scheduled_node {scheduled_node};")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 if scheduled_node.type == "RECV_FORWARD":
-                    self.recv_forward()
+                    self.recv_forward(scheduled_node.chunk)
                 elif scheduled_node.type == "RECV_BACKWARD":
-                    self.recv_backward()
+                    self.recv_backward(scheduled_node.chunk)
                 elif scheduled_node.type == "SEND_FORWARD":
-                    self.send_forward()
+                    self.send_forward(scheduled_node.chunk)
                 elif scheduled_node.type == "SEND_BACKWARD":
-                    self.send_backward()
-            elif scheduled_node.type == "F":
-                self.schedule_f()
+                    self.send_backward(scheduled_node.chunk)
+            if scheduled_node.type == "F":
+                self.schedule_f(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    input_obj=input_obj,
+                    criterion=criterion,
+                    accum_loss=return_loss,
+                    outputs=return_outputs,
+                )
             elif scheduled_node.type == "B":
-                self.schedule_b()
+                self.schedule_b(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                )
             elif scheduled_node.type == "W":
-                self.schedule_w()
+                self.schedule_w(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                )
+            it += 1
diff --git a/tests/test_pipeline/test_schedule/test_dx_dw.py b/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
similarity index 99%
rename from tests/test_pipeline/test_schedule/test_dx_dw.py
rename to tests/test_pipeline/test_schedule/test_zerobubble_poc.py
index 1ade7d45a..ac7ea3f9a 100644
--- a/tests/test_pipeline/test_schedule/test_dx_dw.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_poc.py
@@ -1176,17 +1176,8 @@ def model_chunk_dx_dw_comm_interleaved(
     print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
 
 
-def run_fwd_bwd(
-    rank: int,
-    world_size: int,
-    port: int,
-):
-    pass
-
-
 @rerun_if_address_is_in_use()
 def test_dx_dw_dist():
-
     spawn(
         model_chunk_dx_dw_comm_interleaved,
         nprocs=4,
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index b0927c0c4..a8502c2af 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -8,6 +8,7 @@ from torch.testing import assert_close
 
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
+from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -34,6 +35,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
     return num_params, num_params_trainable
 
 
+# Test baseline; An 8 layer MLP do Zerobubble Pipeline on 4 node pp group;
 def test_zerobubble_pipeline_base(
     rank: int,
     world_size: int,
@@ -427,18 +429,187 @@ def test_zerobubble_pipeline_base(
         assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad)
 
 
+# Test run_forward_backward with baseline;
+def test_run_fwd_bwd_base(
+    rank: int,
+    world_size: int,
+    port: int,
+):
+    # init dist
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+    rank = dist.get_rank()
+    pp_size = world_size
+    pg_mesh = ProcessGroupMesh(pp_size)
+
+    # stage_manager
+    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
+
+    # schedule list
+    zbv_schedule = [
+        # stage 0
+        [
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
+            ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0),
+            ScheduledNode(type="F", chunk=1, stage=0, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0),
+            ScheduledNode(type="B", chunk=1, stage=0, minibatch=0),
+            ScheduledNode(type="W", chunk=1, stage=0, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0),
+            ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
+            ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
+        ],
+        # stage 1
+        [
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0),
+            ScheduledNode(type="F", chunk=0, stage=1, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0),
+            ScheduledNode(type="F", chunk=1, stage=1, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0),
+            ScheduledNode(type="B", chunk=1, stage=1, minibatch=0),
+            ScheduledNode(type="W", chunk=1, stage=1, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0),
+            ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
+            ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
+        ],
+        # stage 2
+        [
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0),
+            ScheduledNode(type="F", chunk=0, stage=2, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0),
+            ScheduledNode(type="F", chunk=1, stage=2, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0),
+            ScheduledNode(type="B", chunk=1, stage=2, minibatch=0),
+            ScheduledNode(type="W", chunk=1, stage=2, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0),
+            ScheduledNode(type="B", chunk=0, stage=2, minibatch=0),
+            ScheduledNode(type="W", chunk=0, stage=2, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0),  # Send nothing
+        ],
+        # stage 3
+        [
+            # chunk 0 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0),
+            ScheduledNode(type="F", chunk=0, stage=3, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0),
+            # chunk 1 fwd
+            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0),
+            ScheduledNode(type="F", chunk=1, stage=3, minibatch=0),
+            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0),
+            # chunk 1 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0),
+            ScheduledNode(type="B", chunk=1, stage=3, minibatch=0),
+            ScheduledNode(type="W", chunk=1, stage=3, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0),
+            # chunk 0 bwd
+            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0),
+            ScheduledNode(type="B", chunk=0, stage=3, minibatch=0),
+            ScheduledNode(type="W", chunk=0, stage=3, minibatch=0),
+            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0),
+        ],
+    ]
+
+    scheduler = ZeroBubbleVPipeScheduler(
+        schedule=zbv_schedule[rank],
+        stage_manager=stage_manager,
+        num_model_chunks=pp_size,
+        num_microbatch=1,
+        overlap_p2p=False,
+    )
+
+    # loss func
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()
+
+    # init model and input
+    num_layers = 8
+    in_dim = out_dim = 8
+    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
+    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
+    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
+
+    input0.clone()
+    deepcopy(model)
+
+    if rank == 0:
+        # layer 0 & 7 to chunk 0 on rank0
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 0 or idx == 7:
+                local_chunk.append(sub_model)
+    elif rank == 1:
+        # layer 1 & 6 to chunk 1 on rank1
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 1 or idx == 6:
+                local_chunk.append(sub_model)
+    elif rank == 2:
+        # layer 2 & 5 to chunk 2 on rank2
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 2 or idx == 5:
+                local_chunk.append(sub_model)
+    else:
+        # layer 3 & 4 to chunk 3 on rank3
+        local_chunk = torch.nn.Sequential().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 3 or idx == 4:
+                local_chunk.append(sub_model)
+    print(
+        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
+    )
+
+    torch.cuda.synchronize()
+    scheduler.run_forward_backward(
+        model_chunk=local_chunk,
+        input_obj=input0,
+        data_iter=None,
+        criterion=criterion,
+        optimizer=None,
+        return_loss=None,
+        return_outputs=None,
+    )
+
+
 # @pytest.mark.dist
 # @pytest.mark.parametrize("num_microbatch", [4])
 # @pytest.mark.parametrize("batch_size", [4])
 # @pytest.mark.parametrize("num_model_chunk", [2])
 @rerun_if_address_is_in_use()
 def test_pp():
+    # spawn(
+    #     test_zerobubble_pipeline_base,
+    #     nprocs=4,
+    # )
+
     spawn(
-        test_zerobubble_pipeline_base,
+        test_run_fwd_bwd_base,
         nprocs=4,
     )
 
 
 if __name__ == "__main__":
-    test_pp()
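
Note (editor): the core pattern in this diff is that `schedule_f`/`schedule_b`/`schedule_w` no longer call the P2P helpers directly; communication nodes (`RECV_FORWARD`, `SEND_FORWARD`, ...) move tensors in and out of per-chunk FIFO buffers, and compute nodes (`F`, `B`, `W`) only pop from and append to those buffers. Below is a minimal, single-process, forward-only sketch of that dispatch pattern. The node and buffer names mirror the diff, but `Node`, `run`, and the lambda "model chunks" are illustrative assumptions for this sketch, not ColossalAI APIs.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Node:
    type: str  # "RECV_FORWARD", "F", "SEND_FORWARD", ...
    chunk: int


# per-chunk FIFO buffers, mirroring recv_forward_buffer / send_forward_buffer in the diff
recv_forward_buffer = {0: deque(), 1: deque()}
send_forward_buffer = {0: deque(), 1: deque()}


def run(schedule, model_chunks, x):
    """Dispatch loop in the spirit of run_forward_backward: comm nodes move data
    between buffers, compute nodes only touch buffers (single process, fwd only)."""
    for node in schedule:
        if node.type == "RECV_FORWARD":
            # stand-in for comm.recv_forward(); here the "wire" is just the other buffer
            recv_forward_buffer[node.chunk].append(
                x if node.chunk == 0 else send_forward_buffer[0].popleft()
            )
        elif node.type == "F":
            inp = recv_forward_buffer[node.chunk].popleft()
            send_forward_buffer[node.chunk].append(model_chunks[node.chunk](inp))
        elif node.type == "SEND_FORWARD":
            pass  # a real scheduler would hand the tensor to the next rank here
    return send_forward_buffer[1].popleft()


if __name__ == "__main__":
    schedule = [
        Node("RECV_FORWARD", 0), Node("F", 0), Node("SEND_FORWARD", 0),
        Node("RECV_FORWARD", 1), Node("F", 1), Node("SEND_FORWARD", 1),
    ]
    out = run(schedule, {0: lambda t: t + 1, 1: lambda t: t * 2}, 3)
    print(out)  # (3 + 1) * 2 = 8
```

Decoupling the compute steps from communication this way is what lets the schedule interleave F/B/W nodes freely: a compute node never blocks on a recv, it simply assumes the preceding RECV node already filled the buffer.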