Browse Source

[fix] fix detach_output_obj clone;

pull/6065/head
duanjunwen 2 months ago
parent
commit
7e6f793c51
  1. 6
      colossalai/pipeline/schedule/zero_bubble_pp.py

6
colossalai/pipeline/schedule/zero_bubble_pp.py

@ -622,10 +622,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
# detach output # detach output
detached_output_obj = tree_map(detach, output_obj) detached_output_obj = tree_map(detach, output_obj)
# 3-2 clone output # 3-2 clone detached_output_obj
output_obj = tree_map(clone, output_obj) detached_output_obj = tree_map(clone, detached_output_obj)
# 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output) # 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
# output_obj = tree_map(clone, output_obj)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not release_tensor_data bwd LOSS # We should not release_tensor_data bwd LOSS
pass pass

Loading…
Cancel
Save