[fix] fix zbv wait_handle

pull/6114/head
duanjunwen 2024-11-15 07:56:14 +00:00
parent 5c2ebbfd48
commit cf86c1b1c5
1 changed files with 27 additions and 19 deletions

View File

@ -115,10 +115,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.output_tensors_grad_dw = [[], []]
# buffer for communication
self.send_forward_buffer = [[], []]
self.recv_forward_buffer = [[], []]
self.send_backward_buffer = [[], []]
self.recv_backward_buffer = [[], []]
self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_forward_buffer = [
[],
[],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_backward_buffer = [
[],
[],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
# y buffer for local send fwd
self.local_send_forward_buffer = []
@ -257,7 +263,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles
else:
@ -280,7 +286,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
@ -316,7 +322,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles
else:
@ -339,7 +345,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
@ -651,9 +657,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 0:
# is first stage; get input from microbatch
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = None
input_obj = None # (tensor, wait_handle)
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
else:
# is last stage; recv from local
if self.stage_manager.is_last_stage(ignore_chunk=True):
@ -661,7 +670,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# not last stage; recv from next
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
# Here, let input_obj.requires_grad_()
# if input_obj is not None:
if not isinstance(input_obj, torch.Tensor):
@ -751,6 +762,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# chunk0 not last stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
else:
# chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
@ -758,6 +772,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0)
@ -886,9 +903,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
wait_handle = communication_func(scheduled_node.chunk)
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F":
for h in self.wait_handles:
for hh in h:
hh.wait()
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
@ -898,9 +912,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs=outputs,
)
elif scheduled_node.type == "B":
for h in self.wait_handles:
for hh in h:
hh.wait()
self.schedule_b(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
@ -914,9 +925,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)