|
|
|
@ -108,6 +108,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# dy buffer for local send bwd
|
|
|
|
|
self.local_send_backward_buffer = []
|
|
|
|
|
|
|
|
|
|
def assert_buffer_empty(self):
|
|
|
|
|
# assert buuffer is empty at end
|
|
|
|
|
assert len(self.input_tensors[0]) == 0
|
|
|
|
|
assert len(self.input_tensors[1]) == 0
|
|
|
|
|
assert len(self.output_tensors[0]) == 0
|
|
|
|
|
assert len(self.output_tensors[1]) == 0
|
|
|
|
|
assert len(self.output_tensors_dw[0]) == 0
|
|
|
|
|
assert len(self.output_tensors_dw[1]) == 0
|
|
|
|
|
assert len(self.output_tensors_grad_dw[0]) == 0
|
|
|
|
|
assert len(self.output_tensors_grad_dw[1]) == 0
|
|
|
|
|
assert len(self.send_forward_buffer[0]) == 0
|
|
|
|
|
assert len(self.send_forward_buffer[1]) == 0
|
|
|
|
|
assert len(self.recv_forward_buffer[0]) == 0
|
|
|
|
|
assert len(self.recv_forward_buffer[1]) == 0
|
|
|
|
|
assert len(self.send_backward_buffer[0]) == 0
|
|
|
|
|
assert len(self.send_backward_buffer[1]) == 0
|
|
|
|
|
assert len(self.recv_backward_buffer[0]) == 0
|
|
|
|
|
assert len(self.recv_backward_buffer[1]) == 0
|
|
|
|
|
assert len(self.local_send_forward_buffer) == 0
|
|
|
|
|
assert len(self.local_send_backward_buffer) == 0
|
|
|
|
|
|
|
|
|
|
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
|
|
|
|
"""Load a batch from data iterator.
|
|
|
|
|
|
|
|
|
@ -546,7 +567,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
)
|
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
# We should not detach bwd LOSS
|
|
|
|
|
detached_output_obj = output_obj.clone()
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
detached_output_obj = output_obj.clone().detach()
|
|
|
|
|
|
|
|
|
@ -555,7 +576,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
if model_chunk_id == 0:
|
|
|
|
|
# is last stage; send to local_send_forward_buffer
|
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
|
|
|
|
detached_output_obj = detached_output_obj.detach()
|
|
|
|
|
self.local_send_forward_buffer.append(detached_output_obj)
|
|
|
|
|
else:
|
|
|
|
|
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
|
|
|
@ -816,4 +836,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.assert_buffer_empty()
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|