diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 9445a4dcd..09ea4000c 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -89,8 +89,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.input_tensors = [[], []]
         self.output_tensors = [[], []]
 
-        # x & y & dy buffer for schedule w
-        self.input_tensors_dw = [[], []]
+        # y & dy buffer for schedule w
         self.output_tensors_dw = [[], []]
         self.output_tensors_grad_dw = [[], []]
 
@@ -111,8 +110,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         assert len(self.input_tensors[1]) == 0
         assert len(self.output_tensors[0]) == 0
         assert len(self.output_tensors[1]) == 0
-        assert len(self.input_tensors_dw[0]) == 0
-        assert len(self.input_tensors_dw[1]) == 0
         assert len(self.output_tensors_dw[0]) == 0
         assert len(self.output_tensors_dw[1]) == 0
         assert len(self.output_tensors_grad_dw[0]) == 0
@@ -528,7 +525,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
@@ -555,7 +551,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            for k, v in input_obj.items():
+            for k, v in output_obj.items():
                 if v.requires_grad:
                     output_obj_.append(output_obj[k])
                     output_obj_grad_.append(output_obj_grad[k])
@@ -636,10 +632,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # add input and output object for backward b
         if input_obj is not None:
             self.input_tensors[model_chunk_id].append(input_obj)
-            self.input_tensors_dw[model_chunk_id].append(input_obj)
         else:
             self.input_tensors[model_chunk_id].append(micro_batch)
-            self.input_tensors_dw[model_chunk_id].append(micro_batch)
 
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not deallocate loss, deallocate other output_obj;
@@ -760,7 +754,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
 
         # get y & dy from buffer
-        input_obj = self.input_tensors_dw[model_chunk_id].pop(0)
         output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
         output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
 
@@ -768,7 +761,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
-            input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index de18ae39b..6fa04d0a3 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -596,7 +596,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     batch_size = test_config["batch_size"]
     num_layers = 8
     assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
-    in_dim = out_dim = 4096
+    in_dim = out_dim = 1024
     before_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)