diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 3ab7907b9..3c19b6027 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -568,29 +568,31 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             outputs=outputs,
         )
 
+        detached_output_obj = output_obj.clone()
+        detached_output_obj.requires_grad_()
+
         # Step3: send fwd
         # add output to send_fwd_buffer
         if model_chunk_id == 0:
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                self.local_send_forward_buffer.append(output_obj)
+                self.local_send_forward_buffer.append(detached_output_obj)
             else:
-                self.send_forward_buffer[model_chunk_id].append(output_obj)
+                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
         else:
             # is first stage; end of fwd; append LOSS to local_send_backward_buffer
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                self.local_send_backward_buffer.append(output_obj)
+                self.local_send_backward_buffer.append(detached_output_obj)
             else:
-                self.send_forward_buffer[model_chunk_id].append(output_obj)
+                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
 
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-        detached_output_obj = output_obj.clone()
-        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
-        self.output_tensors[model_chunk_id].append(detached_output_obj)
+        deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(output_obj)
         # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
+        self.output_tensors_dw[model_chunk_id].append(output_obj)
 
     def schedule_b(
         self,