|
|
|
@ -226,7 +226,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; cause u are chunk 0 in first rank, u have no prev rank; |
|
|
|
|
################# |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
# return None, [] |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -241,7 +240,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: |
|
|
|
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) |
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor) |
|
|
|
|
# return input_tensor, wait_handles |
|
|
|
|
return wait_handles |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
@ -265,7 +263,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: |
|
|
|
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) |
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor) |
|
|
|
|
# return input_tensor, wait_handles |
|
|
|
|
return wait_handles |
|
|
|
|
|
|
|
|
|
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: |
|
|
|
@ -313,7 +310,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; get loss from local |
|
|
|
|
################ |
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
# return None, [] |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
################ |
|
|
|
@ -328,7 +324,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: |
|
|
|
|
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) |
|
|
|
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) |
|
|
|
|
# return output_tensor_grad, wait_handles |
|
|
|
|
return wait_handles |
|
|
|
|
|
|
|
|
|
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: |
|
|
|
@ -665,7 +660,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
accum_loss=accum_loss, |
|
|
|
|
outputs=outputs, |
|
|
|
|
) |
|
|
|
|
# print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};") |
|
|
|
|
|
|
|
|
|
# Step3: |
|
|
|
|
# 3-1:detach output; detach output for send fwd; |
|
|
|
@ -748,20 +742,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
input_obj = self.input_tensors[model_chunk_id].pop(0) |
|
|
|
|
output_obj = self.output_tensors[model_chunk_id].pop(0) |
|
|
|
|
|
|
|
|
|
# # save output_tensor_grad for dw |
|
|
|
|
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): |
|
|
|
|
# # we save loss here |
|
|
|
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj) |
|
|
|
|
# else: |
|
|
|
|
# # we save output_tensor_grad here |
|
|
|
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) |
|
|
|
|
# the_output_obj_grad = [] |
|
|
|
|
# if isinstance(output_obj, dict): |
|
|
|
|
# for (k, v) in output_obj.items(): |
|
|
|
|
# the_output_obj_grad.append(v.requires_grad) |
|
|
|
|
# else: |
|
|
|
|
# the_output_obj_grad.append(output_obj.requires_grad) |
|
|
|
|
|
|
|
|
|
input_object_grad = self.backward_b_step( |
|
|
|
|
model_chunk=model_chunk, |
|
|
|
|
model_chunk_id=model_chunk_id, |
|
|
|
@ -804,20 +784,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
Returns: |
|
|
|
|
Nothing. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# get y & dy from buffer |
|
|
|
|
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0) |
|
|
|
|
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) |
|
|
|
|
WeightGradStore.pop(chunk=model_chunk_id) |
|
|
|
|
|
|
|
|
|
# self.backward_w_step( |
|
|
|
|
# model_chunk=model_chunk, |
|
|
|
|
# model_chunk_id=model_chunk_id, |
|
|
|
|
# optimizer=optimizer, |
|
|
|
|
# output_obj=output_obj, |
|
|
|
|
# output_obj_grad=output_obj_grad, |
|
|
|
|
# ) |
|
|
|
|
|
|
|
|
|
def run_forward_only( |
|
|
|
|
self, |
|
|
|
|
model_chunk: Union[ModuleList, Module], |
|
|
|
@ -890,7 +858,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) |
|
|
|
|
for it in range(len(schedule)): |
|
|
|
|
scheduled_node = schedule[it] |
|
|
|
|
# print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};") |
|
|
|
|
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: |
|
|
|
|
# communication |
|
|
|
|
communication_func = self.communication_map[scheduled_node.type] |
|
|
|
|