[fix] fix detach clone release order;

pull/6065/head
duanjunwen 2024-09-23 04:00:24 +00:00
parent da3220f48c
commit c114d1429a
1 changed file with 20 additions and 18 deletions
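Context for the hunks below: `schedule_f` must hand its forward output to the next stage while keeping the autograd graph alive for the later backward-b / backward-w passes, without pinning the full activation memory. The commit reorders Step 3 to detach -> clone -> release. The `tree_map` helpers the diff relies on are defined elsewhere in the repository; the following is a minimal sketch of plausible stand-ins (assumed here for illustration, not the repository's exact implementations):

    import torch
    from torch.utils._pytree import tree_map  # maps a fn over every tensor leaf of a nested structure

    def detach(x):
        # Graph-free alias: shares storage with x, drops grad_fn.
        return x.detach() if isinstance(x, torch.Tensor) else x

    def clone(x):
        # Differentiable copy: own storage; grad_fn (CloneBackward) keeps the graph reachable.
        return x.clone() if isinstance(x, torch.Tensor) else x

    def release_tensor_data(x):
        # Swap the activation buffer for a 1-element placeholder while keeping the
        # autograd metadata (grad_fn), in the spirit of Megatron-LM's
        # deallocate_output_tensor trick (assumed implementation).
        if isinstance(x, torch.Tensor):
            x.data = torch.empty((1,), device=x.device, dtype=x.dtype)
        return x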

@@ -614,14 +614,24 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             outputs=outputs,
         )
-        # Step3: release_tensor_data output for bwd b & w; (do not detach output)
-        deallocate_output_obj = tree_map(clone, output_obj)
+        # Step3:
+        # 3-1: detach output; detach output for send fwd;
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # We should not detach bwd LOSS
+            pass
+        else:
+            # detach output
+            detached_output_obj = tree_map(detach, output_obj)
+        # 3-2 clone output
+        output_obj = tree_map(clone, output_obj)
+        # 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
+        output_obj = tree_map(clone, output_obj)
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # We should not release_tensor_data bwd LOSS
             pass
         else:
             # release_tensor_data output
-            tree_map(release_tensor_data, deallocate_output_obj)
+            tree_map(release_tensor_data, output_obj)
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
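Why the new order in the hunk above is the safe one: `detach` must run while `output_obj` still owns its real storage (the detached alias is what gets sent forward), and the graph-carrying handle must be `clone`d before `release_tensor_data` so that only the clone's buffer is freed. A toy check of that invariant, using the sketched helpers above (hypothetical standalone code, not part of the commit):

    x = torch.randn(4, requires_grad=True)
    y = (x * 2).relu()            # stands in for output_obj after forward

    sent = detach(y)              # 3-1: graph-free, aliases the original buffer
    y = clone(y)                  # 3-2: fresh storage, grad_fn still points into the graph
    release_tensor_data(y)        # 3-3: frees only the clone's buffer

    assert sent.grad_fn is None and sent.shape == (4,)  # intact tensor to send fwd
    assert y.grad_fn is not None                        # graph handle kept for bwd b & w

On the loss path (model_chunk_id == 1 on the first stage) detach and release are both skipped, since the loss must keep its value and its graph for the local backward.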
@@ -629,33 +639,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
-            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
+            self.output_tensors[model_chunk_id].append(output_obj)
+            self.output_tensors_dw[model_chunk_id].append(output_obj)
         else:
-            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
-            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
+            self.output_tensors[model_chunk_id].append(output_obj)
+            self.output_tensors_dw[model_chunk_id].append(output_obj)
-        # Step4: detach output for send fwd;
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # We should not detach bwd LOSS
-            pass
-        else:
-            # detach output
-            output_obj = tree_map(detach, output_obj)
-
         # add output to send_fwd_buffer
         if model_chunk_id == 0:  # chunk 0
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                self.local_send_forward_buffer.append(output_obj)
+                self.local_send_forward_buffer.append(detached_output_obj)
             else:
-                self.send_forward_buffer[model_chunk_id].append(output_obj)
+                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
         else:  # chunk 1
             # is first stage; end of fwd; do nothing
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 pass
             else:
-                self.send_forward_buffer[model_chunk_id].append(output_obj)
+                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)

     def schedule_b(
         self,
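Net effect of the two hunks: the backward bookkeeping (`output_tensors`, `output_tensors_dw`) now stores the cloned-then-released `output_obj`, of which only the grad_fn is needed for backward b and w; the forward send path (`local_send_forward_buffer` for chunk 0 on the last stage, `send_forward_buffer` otherwise) carries `detached_output_obj`, taken from the original, still-intact storage before the clone and release; and the old Step 4, which detached only after the backward buffers had been filled, is removed as redundant.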