|
|
|
@ -64,10 +64,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
|
|
|
|
|
# P2PMeta cache |
|
|
|
|
self.enable_metadata_cache = enable_metadata_cache |
|
|
|
|
self.send_tensor_metadata = True |
|
|
|
|
self.send_grad_metadata = True |
|
|
|
|
self.tensor_metadata_recv = None |
|
|
|
|
self.grad_metadata_recv = None |
|
|
|
|
|
|
|
|
|
# check send_tensor_metadata, send_grad_metadata |
|
|
|
|
# pp4 as sample, we should follow this meta strategy |
|
|
|
|
# send_tensor_meta(fwd) send_grad_meta(bwd) |
|
|
|
|
# chunk0 | chunk1 chunk0 | chunk 1 |
|
|
|
|
# stage 0 T | F F | T |
|
|
|
|
# stage 1 T | T T | T |
|
|
|
|
# stage 2 T | T T | T |
|
|
|
|
# stage 3 F | T F | T |
|
|
|
|
if stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
self.send_tensor_metadata = [True, False] |
|
|
|
|
self.send_grad_metadata = [False, True] |
|
|
|
|
elif stage_manager.is_last_stage(ignore_chunk=True): |
|
|
|
|
self.send_tensor_metadata = [False, True] |
|
|
|
|
self.send_grad_metadata = [True, False] |
|
|
|
|
else: |
|
|
|
|
self.send_tensor_metadata = [True, True] |
|
|
|
|
self.send_grad_metadata = [True, True] |
|
|
|
|
|
|
|
|
|
# meta cache buffer |
|
|
|
|
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta] |
|
|
|
|
self.grad_metadata_recv = [None, None] |
|
|
|
|
|
|
|
|
|
# P2P communication |
|
|
|
|
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) |
|
|
|
@ -96,10 +114,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
self.output_tensors_grad_dw = [[], []] |
|
|
|
|
|
|
|
|
|
# buffer for communication |
|
|
|
|
self.send_forward_buffer = [[], []] |
|
|
|
|
self.recv_forward_buffer = [[], []] |
|
|
|
|
self.send_backward_buffer = [[], []] |
|
|
|
|
self.recv_backward_buffer = [[], []] |
|
|
|
|
self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] |
|
|
|
|
self.recv_forward_buffer = [ |
|
|
|
|
[], |
|
|
|
|
[], |
|
|
|
|
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] |
|
|
|
|
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] |
|
|
|
|
self.recv_backward_buffer = [ |
|
|
|
|
[], |
|
|
|
|
[], |
|
|
|
|
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] |
|
|
|
|
|
|
|
|
|
# y buffer for local send fwd |
|
|
|
|
self.local_send_forward_buffer = [] |
|
|
|
@ -225,7 +249,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; cause u are chunk 0 in first rank, u have no prev rank; |
|
|
|
|
################# |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
# return None, [] |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -235,12 +258,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else: |
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank() |
|
|
|
|
input_tensor, wait_handles = self.comm.recv_forward( |
|
|
|
|
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv |
|
|
|
|
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] |
|
|
|
|
) |
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv is None: |
|
|
|
|
self.tensor_metadata_recv = create_send_metadata(input_tensor) |
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor) |
|
|
|
|
# return input_tensor, wait_handles |
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: |
|
|
|
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) |
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) |
|
|
|
|
return wait_handles |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
@ -259,12 +281,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else: |
|
|
|
|
next_rank = self.stage_manager.get_next_rank() |
|
|
|
|
input_tensor, wait_handles = self.comm.recv_forward( |
|
|
|
|
next_rank, metadata_recv=self.tensor_metadata_recv |
|
|
|
|
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] |
|
|
|
|
) |
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv is None: |
|
|
|
|
self.tensor_metadata_recv = create_send_metadata(input_tensor) |
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor) |
|
|
|
|
# return input_tensor, wait_handles |
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: |
|
|
|
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) |
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) |
|
|
|
|
return wait_handles |
|
|
|
|
|
|
|
|
|
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: |
|
|
|
@ -287,7 +308,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; Already get dy from local_send_backward_buffer in schedule b |
|
|
|
|
################ |
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True): |
|
|
|
|
# return None, [] |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -297,12 +317,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else: |
|
|
|
|
next_rank = self.stage_manager.get_next_rank() |
|
|
|
|
output_tensor_grad, wait_handles = self.comm.recv_backward( |
|
|
|
|
next_rank, metadata_recv=self.grad_metadata_recv |
|
|
|
|
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] |
|
|
|
|
) |
|
|
|
|
if self.enable_metadata_cache and self.grad_metadata_recv is None: |
|
|
|
|
self.grad_metadata_recv = create_send_metadata(output_tensor_grad) |
|
|
|
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) |
|
|
|
|
# return output_tensor_grad, wait_handles |
|
|
|
|
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: |
|
|
|
|
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) |
|
|
|
|
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) |
|
|
|
|
return wait_handles |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
@ -312,7 +331,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; get loss from local |
|
|
|
|
################ |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
# return None, [] |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -322,12 +340,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else: |
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank() |
|
|
|
|
output_tensor_grad, wait_handles = self.comm.recv_backward( |
|
|
|
|
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv |
|
|
|
|
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] |
|
|
|
|
) |
|
|
|
|
if self.enable_metadata_cache and self.grad_metadata_recv is None: |
|
|
|
|
self.grad_metadata_recv = create_send_metadata(output_tensor_grad) |
|
|
|
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) |
|
|
|
|
# return output_tensor_grad, wait_handles |
|
|
|
|
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: |
|
|
|
|
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) |
|
|
|
|
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) |
|
|
|
|
return wait_handles |
|
|
|
|
|
|
|
|
|
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: |
|
|
|
@ -349,6 +366,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; hold y on local_send_forward_buffer |
|
|
|
|
################ |
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True): |
|
|
|
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -359,9 +377,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
next_rank = self.stage_manager.get_next_rank() |
|
|
|
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) |
|
|
|
|
send_handles = self.comm.send_forward( |
|
|
|
|
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata |
|
|
|
|
output_object=output_tensor, |
|
|
|
|
next_rank=next_rank, |
|
|
|
|
send_metadata=self.send_tensor_metadata[model_chunk_id], |
|
|
|
|
) |
|
|
|
|
self.send_tensor_metadata = not self.enable_metadata_cache |
|
|
|
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return send_handles |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
@ -370,6 +390,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part |
|
|
|
|
################ |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -380,9 +401,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank() |
|
|
|
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) |
|
|
|
|
send_handles = self.comm.send_forward( |
|
|
|
|
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata |
|
|
|
|
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id] |
|
|
|
|
) |
|
|
|
|
self.send_tensor_metadata = not self.enable_metadata_cache |
|
|
|
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return send_handles |
|
|
|
|
|
|
|
|
|
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: |
|
|
|
@ -405,6 +426,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; cause u are the first chunk in first stage; bwd end |
|
|
|
|
################ |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -415,9 +437,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank() |
|
|
|
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) |
|
|
|
|
send_handles = self.comm.send_backward( |
|
|
|
|
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata |
|
|
|
|
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id] |
|
|
|
|
) |
|
|
|
|
self.send_grad_metadata = not self.enable_metadata_cache |
|
|
|
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return send_handles |
|
|
|
|
|
|
|
|
|
# bwd chunk1 is left V; |
|
|
|
@ -427,6 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; |
|
|
|
|
################ |
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True): |
|
|
|
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -437,9 +460,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
next_rank = self.stage_manager.get_next_rank() |
|
|
|
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) |
|
|
|
|
send_handles = self.comm.send_backward( |
|
|
|
|
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata |
|
|
|
|
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id] |
|
|
|
|
) |
|
|
|
|
self.send_grad_metadata = not self.enable_metadata_cache |
|
|
|
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache |
|
|
|
|
return send_handles |
|
|
|
|
|
|
|
|
|
def forward_step( |
|
|
|
@ -519,8 +542,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
output_obj_grad_ = [] |
|
|
|
|
|
|
|
|
|
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx. |
|
|
|
|
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
# return None |
|
|
|
|
|
|
|
|
|
# For loss backward; output_obj is loss; output_obj_grad should be None |
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
@ -633,9 +654,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
if model_chunk_id == 0: |
|
|
|
|
# is first stage; get input from microbatch |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
input_obj = None |
|
|
|
|
input_obj = None # (tensor, wait_handle) |
|
|
|
|
else: |
|
|
|
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) |
|
|
|
|
for h in input_obj[1]: |
|
|
|
|
h.wait() |
|
|
|
|
input_obj = input_obj[0] |
|
|
|
|
else: |
|
|
|
|
# is last stage; recv from local |
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True): |
|
|
|
@ -643,7 +667,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# not last stage; recv from next |
|
|
|
|
else: |
|
|
|
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) |
|
|
|
|
|
|
|
|
|
for h in input_obj[1]: |
|
|
|
|
h.wait() |
|
|
|
|
input_obj = input_obj[0] |
|
|
|
|
# Here, let input_obj.requires_grad_() |
|
|
|
|
# if input_obj is not None: |
|
|
|
|
if not isinstance(input_obj, torch.Tensor): |
|
|
|
@ -689,10 +715,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# Do not release_tensor_data loss, release_tensor_data other output_obj; |
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
self.output_tensors[model_chunk_id].append(output_obj) |
|
|
|
|
# self.output_tensors_dw[model_chunk_id].append(output_obj) |
|
|
|
|
else: |
|
|
|
|
self.output_tensors[model_chunk_id].append(output_obj) |
|
|
|
|
# self.output_tensors_dw[model_chunk_id].append(output_obj) |
|
|
|
|
|
|
|
|
|
# add output to send_fwd_buffer |
|
|
|
|
if model_chunk_id == 0: # chunk 0 |
|
|
|
@ -732,6 +756,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# chunk0 not last stage; recv output_grad from recv_backward_buffer |
|
|
|
|
else: |
|
|
|
|
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) |
|
|
|
|
for h in output_tensor_grad[1]: |
|
|
|
|
h.wait() |
|
|
|
|
output_tensor_grad = output_tensor_grad[0] |
|
|
|
|
else: |
|
|
|
|
# chunk1, is first stage; recv LOSS from local send bwd buffer |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
@ -739,25 +766,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# chunk1, not first stage; recv output_grad from recv_backward_buffer |
|
|
|
|
else: |
|
|
|
|
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) |
|
|
|
|
for h in output_tensor_grad[1]: |
|
|
|
|
h.wait() |
|
|
|
|
output_tensor_grad = output_tensor_grad[0] |
|
|
|
|
|
|
|
|
|
# get input and output object from buffer; |
|
|
|
|
input_obj = self.input_tensors[model_chunk_id].pop(0) |
|
|
|
|
output_obj = self.output_tensors[model_chunk_id].pop(0) |
|
|
|
|
|
|
|
|
|
# # save output_tensor_grad for dw |
|
|
|
|
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
# # we save loss here |
|
|
|
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj) |
|
|
|
|
# else: |
|
|
|
|
# # we save output_tensor_grad here |
|
|
|
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) |
|
|
|
|
# the_output_obj_grad = [] |
|
|
|
|
# if isinstance(output_obj, dict): |
|
|
|
|
# for (k, v) in output_obj.items(): |
|
|
|
|
# the_output_obj_grad.append(v.requires_grad) |
|
|
|
|
# else: |
|
|
|
|
# the_output_obj_grad.append(output_obj.requires_grad) |
|
|
|
|
|
|
|
|
|
input_object_grad = self.backward_b_step( |
|
|
|
|
model_chunk=model_chunk, |
|
|
|
|
model_chunk_id=model_chunk_id, |
|
|
|
@ -800,20 +816,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
Returns: |
|
|
|
|
Nothing. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# get y & dy from buffer |
|
|
|
|
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0) |
|
|
|
|
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) |
|
|
|
|
WeightGradStore.pop(chunk=model_chunk_id) |
|
|
|
|
|
|
|
|
|
# self.backward_w_step( |
|
|
|
|
# model_chunk=model_chunk, |
|
|
|
|
# model_chunk_id=model_chunk_id, |
|
|
|
|
# optimizer=optimizer, |
|
|
|
|
# output_obj=output_obj, |
|
|
|
|
# output_obj_grad=output_obj_grad, |
|
|
|
|
# ) |
|
|
|
|
|
|
|
|
|
def run_forward_only( |
|
|
|
|
self, |
|
|
|
|
model_chunk: Union[ModuleList, Module], |
|
|
|
@ -890,6 +894,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# communication |
|
|
|
|
communication_func = self.communication_map[scheduled_node.type] |
|
|
|
|
wait_handle = communication_func(scheduled_node.chunk) |
|
|
|
|
# We wait recv handle in fwd step and bwd step. Here only need to wait for send handle |
|
|
|
|
if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}: |
|
|
|
|
self.wait_handles.append(wait_handle) |
|
|
|
|
elif scheduled_node.type == "F": |
|
|
|
|
self.schedule_f( |
|
|
|
@ -914,10 +920,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
model_chunk_id=scheduled_node.chunk, |
|
|
|
|
optimizer=optimizer, |
|
|
|
|
) |
|
|
|
|
# wait here to ensure all communication is done |
|
|
|
|
for h in self.wait_handles: |
|
|
|
|
for hh in h: |
|
|
|
|
hh.wait() |
|
|
|
|
|
|
|
|
|
# return loss & output |
|
|
|
|
if outputs is not None: |
|
|
|
|
outputs = merge_batch(outputs) |
|
|
|
|