@@ -430,7 +430,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             # fwd calculate
             internal_inputs = {} if input_obj is None else input_obj
-            # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
             output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)

             # last layer in model
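
The hunk above un-comments the `stage_index` entry, so the per-chunk layer range is now threaded to the model forward through `internal_inputs`. As a hedged illustration of why the forward wants this with interleaved virtual stages — the `ShardedForward` class and its layer slicing below are invented for illustration, not ColossalAI's actual shard:

# Hypothetical sketch: a chunked forward consuming stage_index, assuming
# stage_indices[model_chunk_id] yields the [start, end) layer slice that
# this virtual stage owns.
import torch
import torch.nn as nn

class ShardedForward(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, hidden_states: torch.Tensor, stage_index):
        start, end = stage_index  # the [start, end) slice owned by this chunk
        for layer in self.layers[start:end]:
            hidden_states = layer(hidden_states)
        return hidden_states

model = ShardedForward(nn.ModuleList(nn.Linear(8, 8) for _ in range(6)))
out = model(torch.randn(2, 8), stage_index=(2, 4))  # runs layers 2 and 3 only
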
@@ -480,22 +480,26 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten(micro_batch)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy

         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            input_obj_, _ = tree_flatten(input_obj)
+            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(output_obj_grad)  # None

         # For other chunk stage, use input_obj as input_obj_;
         else:
-            input_obj_, _ = tree_flatten(input_obj)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy

         optimizer.backward_by_grad(
             tensor=output_obj_,
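
Across all three branches the fix is the same: drop non-tensor entries before flattening, because the flattened y/dy lists are paired positionally and handed to autograd, which cannot take non-tensor leaves — including the `stage_index` metadata that the first hunk now threads through the inputs. A minimal, standalone reproduction of that pitfall, assuming `tree_flatten` comes from `torch.utils._pytree` (the example dict contents are invented):

import torch
from torch.utils._pytree import tree_flatten

# An output dict mixing a real activation with non-tensor metadata.
output_obj = {
    "hidden_states": torch.randn(2, 4, requires_grad=True),
    "stage_index": [2, 4],  # non-tensor entry riding along with the outputs
}

# Unfiltered: every leaf comes back, ints included.
leaves, _ = tree_flatten(output_obj)
print([type(v).__name__ for v in leaves])  # ['Tensor', 'int', 'int']

# Filtered, as in the diff: only tensors survive, so the flat y/dy lists
# stay aligned and are safe to hand to torch.autograd.backward.
tensors, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})
print([type(v).__name__ for v in tensors])  # ['Tensor']
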
@@ -547,8 +551,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy

         optimizer.backward_by_grad(
             tensor=output_obj_,
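
For context on what the matched lists feed into: `backward_by_grad` ultimately drives a `torch.autograd.backward`-style call with paired outputs and upstream gradients, and the loss branch pairs the scalar loss with a `None` gradient, for which autograd supplies the implicit 1.0. A hedged sketch of that contract — this mirrors, but is not, ColossalAI's `OptimizerWrapper` implementation:

import torch

# dy case: an intermediate output y with an explicit upstream gradient dy.
x = torch.randn(3, requires_grad=True)
y = x * 2
torch.autograd.backward(tensors=[y], grad_tensors=[torch.ones_like(y)])
print(x.grad)  # tensor([2., 2., 2.])

# loss case: a scalar paired with grad None; autograd uses the implicit 1.0.
x2 = torch.randn(3, requires_grad=True)
loss = (x2 * 2).sum()
torch.autograd.backward(tensors=[loss], grad_tensors=[None])
print(x2.grad)  # tensor([2., 2., 2.])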