@@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        micro_batch: Optional[dict],
+        # micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -480,9 +480,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten(micro_batch)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            # input_obj_, _ = tree_flatten(micro_batch)
+            # output_obj_, _ = tree_flatten(output_obj)  # y
+            # output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            return None
 
         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -512,9 +513,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Format output_obj_grad
         input_obj_grad = {}
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            for k, v in micro_batch.items():
-                if isinstance(v, torch.Tensor) and v.grad is not None:
-                    input_obj_grad[k] = v.grad
+            # for k, v in micro_batch.items():
+            #     if isinstance(v, torch.Tensor) and v.grad is not None:
+            #         input_obj_grad[k] = v.grad
+            pass
         else:
             for k, v in input_obj.items():
                 if isinstance(v, torch.Tensor) and v.grad is not None:
@@ -643,7 +645,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             tree_map(release_tensor_data, output_obj)
 
         # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+        # self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+        self.input_tensors[model_chunk_id].append(input_obj)
 
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
@@ -701,7 +704,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
 
         # get input and output object from buffer;
-        micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
+        # micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
+        input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
 
         # save output_tensor_grad for dw
@@ -717,7 +721,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
-            micro_batch=micro_batch,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
@@ -838,6 +841,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         # while we still have schedules_node in self.schedules
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
+        print(f"schedule {schedule}")
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: