diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 9445a4dcd..09ea4000c 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -89,8 +89,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): self.input_tensors = [[], []] self.output_tensors = [[], []] - # x & y & dy buffer for schedule w - self.input_tensors_dw = [[], []] + # y & dy buffer for schedule w self.output_tensors_dw = [[], []] self.output_tensors_grad_dw = [[], []] @@ -111,8 +110,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): assert len(self.input_tensors[1]) == 0 assert len(self.output_tensors[0]) == 0 assert len(self.output_tensors[1]) == 0 - assert len(self.input_tensors_dw[0]) == 0 - assert len(self.input_tensors_dw[1]) == 0 assert len(self.output_tensors_dw[0]) == 0 assert len(self.output_tensors_dw[1]) == 0 assert len(self.output_tensors_grad_dw[0]) == 0 @@ -528,7 +525,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, - input_obj: Optional[dict], output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], ): @@ -555,7 +551,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_.append(output_obj) # LOSS output_obj_grad_.append(None) # None else: - for k, v in input_obj.items(): + # for k, v in input_obj.items(): + # if v.requires_grad: + # output_obj_.append(output_obj[k]) + # output_obj_grad_.append(output_obj_grad[k]) + for k, v in output_obj.items(): if v.requires_grad: output_obj_.append(output_obj[k]) output_obj_grad_.append(output_obj_grad[k]) @@ -636,10 +636,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # add input and output object for backward b if input_obj is not None: self.input_tensors[model_chunk_id].append(input_obj) - self.input_tensors_dw[model_chunk_id].append(input_obj) else: self.input_tensors[model_chunk_id].append(micro_batch) - self.input_tensors_dw[model_chunk_id].append(micro_batch) # for bwd b&w, we only need the graph(grad_fn) of output_obj # Do not deallocate loss, deallocate other output_obj; @@ -760,7 +758,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): """ # get y & dy from buffer - input_obj = self.input_tensors_dw[model_chunk_id].pop(0) output_obj = self.output_tensors_dw[model_chunk_id].pop(0) output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) @@ -768,7 +765,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk=model_chunk, model_chunk_id=model_chunk_id, optimizer=optimizer, - input_obj=input_obj, output_obj=output_obj, output_obj_grad=output_obj_grad, ) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index de18ae39b..6fa04d0a3 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -596,7 +596,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): batch_size = test_config["batch_size"] num_layers = 8 assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" - in_dim = out_dim = 4096 + in_dim = out_dim = 1024 before_init_memory = torch.cuda.memory_allocated() / 1024**3 print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)