@@ -478,11 +478,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_ = []
         output_obj_grad_ = []
 
-        # For chunk 0 stage 0, use micro_batch as input_obj_
+        # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to calculate the micro-batch dx.
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # input_obj_, _ = tree_flatten(micro_batch)
-            # output_obj_, _ = tree_flatten(output_obj)  # y
-            # output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
             return None
 
         # For loss backward; output_obj is the loss, so output_obj_grad should be None
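
For context on the branch above: when output_obj is the loss, autograd is driven with no upstream gradient; otherwise each flattened output tensor is paired with the dy received from the next stage. A minimal sketch of that dispatch (illustrative only; the helper name and is_loss flag are not from this diff):

    import torch

    def _backward_by_grad(output_obj_, output_obj_grad_, is_loss):
        if is_loss:
            # Loss backward: output_obj_ holds the scalar loss; with no
            # grad_tensors, autograd uses an implicit gradient of 1.0.
            torch.autograd.backward(output_obj_[0])
        else:
            # Regular backward: pair each flattened output y with its received dy.
            torch.autograd.backward(tensors=output_obj_, grad_tensors=output_obj_grad_)
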
@@ -513,9 +510,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Format output_obj_grad
         input_obj_grad = {}
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # for k, v in micro_batch.items():
-            #     if isinstance(v, torch.Tensor) and v.grad is not None:
-            #         input_obj_grad[k] = v.grad
             pass
         else:
             for k, v in input_obj.items():
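
The loop above (truncated by the hunk) mirrors the commented-out micro_batch version being deleted; the pattern is to collect each input tensor's .grad as the dx sent to the previous stage. A self-contained sketch, names illustrative:

    import torch

    def collect_input_grads(input_obj: dict) -> dict:
        # Each input tensor's accumulated .grad becomes the dx forwarded
        # upstream, keyed the same way as the input dict.
        input_obj_grad = {}
        for k, v in input_obj.items():
            if isinstance(v, torch.Tensor) and v.grad is not None:
                input_obj_grad[k] = v.grad
        return input_obj_grad
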
@@ -645,7 +639,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             tree_map(release_tensor_data, output_obj)
 
         # add input and output objects for backward b
-        # self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
         self.input_tensors[model_chunk_id].append(input_obj)
 
         # for bwd b & w, we only need the graph (grad_fn) of output_obj
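
release_tensor_data presumably frees the output's storage while leaving its autograd node intact, which is why only the grad_fn graph has to be kept for backward b and w. A sketch of that idea (assumed behavior, not the scheduler's actual helper):

    import torch

    def release_tensor_data(t):
        # Drop the activation storage but keep grad_fn alive: the later
        # backward passes need the graph hanging off output_obj, not its values.
        if isinstance(t, torch.Tensor) and t.grad_fn is not None:
            t.data = torch.empty((1,), device=t.device, dtype=t.dtype)
        return t
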
@@ -704,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
 
         # get input and output objects from the buffers
-        # micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
         input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
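
The append/pop(0) pairing is plain FIFO bookkeeping: each forward appends one entry per micro-batch and each backward b consumes the oldest, so the i-th backward always sees the i-th forward's tensors. A toy sketch (buffer names mirror the scheduler's; the handlers are illustrative):

    input_tensors = {0: [], 1: []}   # model_chunk_id -> FIFO of input_obj
    output_tensors = {0: [], 1: []}  # model_chunk_id -> FIFO of output_obj

    def on_forward_done(chunk_id, input_obj, output_obj):
        input_tensors[chunk_id].append(input_obj)
        output_tensors[chunk_id].append(output_obj)

    def on_backward_b_start(chunk_id):
        # Oldest forward first, matching the pop(0) calls in the hunk above.
        return input_tensors[chunk_id].pop(0), output_tensors[chunk_id].pop(0)
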
@@ -841,7 +833,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         # while we still have scheduled nodes in self.schedules
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
-        print(f"schedule {schedule}")
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
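
The loop dispatches each node of the per-stage schedule by type; schematically it behaves like the sketch below (the handler names and the exact members of AUTO_SCHEDULE_COMMUNICATION_TYPES are assumptions, not from this diff):

    # Hypothetical stand-ins for the scheduler's real node types and handlers.
    AUTO_SCHEDULE_COMMUNICATION_TYPES = {"SEND_FORWARD", "RECV_FORWARD", "SEND_BACKWARD", "RECV_BACKWARD"}

    def run_schedule(schedule, run_communication, run_compute):
        for scheduled_node in schedule:
            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                run_communication(scheduled_node)   # p2p send/recv between stages
            else:
                run_compute(scheduled_node)         # F (forward), B (backward b), W (backward w)
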