|
|
|
@ -614,14 +614,24 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
outputs=outputs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Step3: release_tensor_data output for bwd b & w; (do not detach output)
|
|
|
|
|
deallocate_output_obj = tree_map(clone, output_obj)
|
|
|
|
|
# Step3:
|
|
|
|
|
# 3-1:detach output; detach output for send fwd;
|
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
# We should not detach bwd LOSS
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
# detach output
|
|
|
|
|
detached_output_obj = tree_map(detach, output_obj)
|
|
|
|
|
# 3-2 clone output
|
|
|
|
|
output_obj = tree_map(clone, output_obj)
|
|
|
|
|
# 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
|
|
|
|
|
output_obj = tree_map(clone, output_obj)
|
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
# We should not release_tensor_data bwd LOSS
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
# release_tensor_data output
|
|
|
|
|
tree_map(release_tensor_data, deallocate_output_obj)
|
|
|
|
|
tree_map(release_tensor_data, output_obj)
|
|
|
|
|
|
|
|
|
|
# add input and output object for backward b
|
|
|
|
|
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
|
|
|
|
@ -629,33 +639,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# for bwd b&w, we only need the graph(grad_fn) of output_obj
|
|
|
|
|
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
|
|
|
|
|
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
|
|
|
|
|
self.output_tensors[model_chunk_id].append(output_obj)
|
|
|
|
|
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
|
|
|
|
else:
|
|
|
|
|
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
|
|
|
|
|
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
|
|
|
|
|
|
|
|
|
|
# Step4: detach output for send fwd;
|
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
# We should not detach bwd LOSS
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
# detach output
|
|
|
|
|
output_obj = tree_map(detach, output_obj)
|
|
|
|
|
self.output_tensors[model_chunk_id].append(output_obj)
|
|
|
|
|
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
|
|
|
|
|
|
|
|
|
# add output to send_fwd_buffer
|
|
|
|
|
if model_chunk_id == 0: # chunk 0
|
|
|
|
|
# is last stage; send to local_send_forward_buffer
|
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
|
|
|
|
self.local_send_forward_buffer.append(output_obj)
|
|
|
|
|
self.local_send_forward_buffer.append(detached_output_obj)
|
|
|
|
|
else:
|
|
|
|
|
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
|
|
|
|
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
|
|
|
|
else: # chunk 1
|
|
|
|
|
# is first stage; end of fwd; do nothing
|
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
|
|
|
|
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
|
|
|
|
|
|
|
|
|
def schedule_b(
|
|
|
|
|
self,
|
|
|
|
|