[feat] add comments for ZBV func;

pull/6034/head
duanjunwen 2024-08-27 07:11:50 +00:00
parent f1c1a87246
commit 1b4bb2beeb
1 changed files with 66 additions and 16 deletions

View File

@ -40,9 +40,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.num_microbatch = num_microbatch
self.collect_non_loss_data = None
self.forward_only = None
self.schedules = schedule
self.it = 0 # curr iteration
# TODO: optim post valid
self.do_post_validation = False
self.is_first_run = True
self.optimizer = None
@ -69,16 +68,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.input_tensors = [[], []]
self.output_tensors = [[], []]
# y & dy buffer for schedule b
# y & dy buffer for schedule w
self.output_tensors_dw = [[], []]
self.output_tensors_grad_dw = [[], []]
# buffer for communication
self.send_forward_buffer = [[], []]
self.recv_forward_buffer = [[], []]
self.send_backward_buffer = [[], []]
self.recv_backward_buffer = [[], []]
self.forward_data_store = []
# y buffer for local send fwd
self.local_send_forward_buffer = []
# dy buffer for local send bwd
self.local_send_backward_buffer = []
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
@ -263,7 +265,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Args:
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
@ -313,7 +314,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Args:
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
@ -371,9 +371,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model (ModuleList or Module): Model Chunk to be run
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
criterion (Callable): Criterion to calculate loss.
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
input_obj (Optional[dict]): x;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
@ -410,16 +411,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
) -> Optional[dict]:
"""Backward one step of the pipeline
"""Backward dx step of the pipeline; we calculate "dx = w*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
input_obj (Optional[dict]): x.
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
Optional[dict]: dx.
"""
# calculate bwd b step ; only dx = w*dy;
@ -451,10 +454,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
# optimizer: OptimizerWrapper,
# input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
):
"""Backward dw step of the pipeline; we calculate "dw = x*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Nothing need to return; we only calculate dw then update w;
"""
# calculate bwd w step ; only dw = x*dy;
if model_chunk_id == 0:
torch.autograd.backward(
@ -481,6 +495,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
):
"""A complete forward schedule; Include recv fwd --> cal fwd --> send fwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
input_obj (Optional[dict]): x;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Nothing.
"""
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from func param
@ -541,6 +569,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# output_obj: Union[dict, torch.Tensor],
# output_obj_grad: Optional[dict],
):
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
# Step1: recv bwd
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
@ -606,6 +644,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id: int,
# optimizer: OptimizerWrapper,
):
"""A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
# get y & dy from buffer
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
@ -629,7 +676,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
return_loss: bool = False,
return_outputs: bool = False,
):
it = self.it
"""
Runs Zerobubble schedule, with communication between pipeline stages.
"""
it = 0
# while we still have schedules_node in self.schedules
# print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
while it < len(self.schedules):