diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 9771277e2..ae35bc967 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -614,14 +614,24 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             outputs=outputs,
         )

-        # Step3: release_tensor_data output for bwd b & w; (do not detach output)
-        deallocate_output_obj = tree_map(clone, output_obj)
+        # Step3:
+        # 3-1: detach output; detach output for send fwd;
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # We should not detach bwd LOSS
+            pass
+        else:
+            # detach output
+            detached_output_obj = tree_map(detach, output_obj)
+        # 3-2 clone output
+        output_obj = tree_map(clone, output_obj)
+        # 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
+        output_obj = tree_map(clone, output_obj)
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # We should not release_tensor_data bwd LOSS
             pass
         else:
             # release_tensor_data output
-            tree_map(release_tensor_data, deallocate_output_obj)
+            tree_map(release_tensor_data, output_obj)

         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
@@ -629,33 +639,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
-            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
+            self.output_tensors[model_chunk_id].append(output_obj)
+            self.output_tensors_dw[model_chunk_id].append(output_obj)
         else:
-            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
-            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
-
-        # Step4: detach output for send fwd;
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # We should not detach bwd LOSS
-            pass
-        else:
-            # detach output
-            output_obj = tree_map(detach, output_obj)
+            self.output_tensors[model_chunk_id].append(output_obj)
+            self.output_tensors_dw[model_chunk_id].append(output_obj)

         # add output to send_fwd_buffer
         if model_chunk_id == 0:  # chunk 0
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                self.local_send_forward_buffer.append(output_obj)
+                self.local_send_forward_buffer.append(detached_output_obj)
             else:
-                self.send_forward_buffer[model_chunk_id].append(output_obj)
+                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
         else:  # chunk 1
             # is first stage; end of fwd; do nothing
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 pass
             else:
-                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)

     def schedule_b(
         self,
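
For context, the reordered Step 3 hinges on a PyTorch property: rebinding `tensor.data` swaps the underlying storage while leaving `grad_fn` intact, so the scheduler can keep a graph handle for backward b & w after freeing the activation memory, while the `detach`ed copy placed in the send buffers carries no graph at all. Below is a minimal sketch of that split; the `release_tensor_data` body is a hypothetical stand-in modeled on the common `.data = torch.empty(...)` trick, not ColossalAI's actual implementation, and the variable names simply mirror the diff.

```python
import torch

def release_tensor_data(t: torch.Tensor) -> None:
    # Hypothetical stand-in for colossalai's release_tensor_data:
    # rebinding .data swaps the storage (freeing the activation)
    # but preserves autograd metadata, so t.grad_fn survives.
    t.data = torch.empty((1,), device=t.device, dtype=t.dtype)

x = torch.randn(4, requires_grad=True)
output_obj = (x * 3).sum()          # stands in for one chunk's forward output

# 3-1: detach -> graph-free tensor, safe to hand to the send buffer
detached_output_obj = output_obj.detach()
# 3-2/3-3: clone, then release the clone's storage; the clone's
# CloneBackward node still reaches the original graph for bwd b & w
output_obj = output_obj.clone()
release_tensor_data(output_obj)

assert detached_output_obj.grad_fn is None   # nothing to backprop through
assert output_obj.grad_fn is not None        # graph kept for schedule_b
```

The loss on chunk 1's first stage is exempted from both branches, presumably because backward must later be driven from the loss tensor itself: detaching it or releasing its data would sever the only root of that graph.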