mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix zbv wait_handle
parent
5c2ebbfd48
commit
cf86c1b1c5
|
@ -115,10 +115,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
self.output_tensors_grad_dw = [[], []]
|
||||
|
||||
# buffer for communication
|
||||
self.send_forward_buffer = [[], []]
|
||||
self.recv_forward_buffer = [[], []]
|
||||
self.send_backward_buffer = [[], []]
|
||||
self.recv_backward_buffer = [[], []]
|
||||
self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
|
||||
self.recv_forward_buffer = [
|
||||
[],
|
||||
[],
|
||||
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
|
||||
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
|
||||
self.recv_backward_buffer = [
|
||||
[],
|
||||
[],
|
||||
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
|
||||
|
||||
# y buffer for local send fwd
|
||||
self.local_send_forward_buffer = []
|
||||
|
@ -257,7 +263,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
else:
|
||||
|
@ -280,7 +286,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||
|
@ -316,7 +322,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
else:
|
||||
|
@ -339,7 +345,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
|
||||
return wait_handles
|
||||
|
||||
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||
|
@ -651,9 +657,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
if model_chunk_id == 0:
|
||||
# is first stage; get input from microbatch
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = None
|
||||
input_obj = None # (tensor, wait_handle)
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
for h in input_obj[1]:
|
||||
h.wait()
|
||||
input_obj = input_obj[0]
|
||||
else:
|
||||
# is last stage; recv from local
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
|
@ -661,7 +670,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# not last stage; recv from next
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
for h in input_obj[1]:
|
||||
h.wait()
|
||||
input_obj = input_obj[0]
|
||||
# Here, let input_obj.requires_grad_()
|
||||
# if input_obj is not None:
|
||||
if not isinstance(input_obj, torch.Tensor):
|
||||
|
@ -751,6 +762,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# chunk0 not last stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
for h in output_tensor_grad[1]:
|
||||
h.wait()
|
||||
output_tensor_grad = output_tensor_grad[0]
|
||||
else:
|
||||
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
|
@ -758,6 +772,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
for h in output_tensor_grad[1]:
|
||||
h.wait()
|
||||
output_tensor_grad = output_tensor_grad[0]
|
||||
|
||||
# get input and output object from buffer;
|
||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||
|
@ -886,9 +903,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
wait_handle = communication_func(scheduled_node.chunk)
|
||||
self.wait_handles.append(wait_handle)
|
||||
elif scheduled_node.type == "F":
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
|
@ -898,9 +912,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
outputs=outputs,
|
||||
)
|
||||
elif scheduled_node.type == "B":
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
self.schedule_b(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
|
@ -914,9 +925,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
model_chunk_id=scheduled_node.chunk,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
# return loss & output
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
|
|
Loading…
Reference in New Issue