mirror of https://github.com/hpcaitech/ColossalAI
[feat] add comments for ZBV func;
parent
f1c1a87246
commit
1b4bb2beeb
|
@ -40,9 +40,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
self.num_microbatch = num_microbatch
|
||||
self.collect_non_loss_data = None
|
||||
self.forward_only = None
|
||||
|
||||
self.schedules = schedule
|
||||
self.it = 0 # curr iteration
|
||||
# TODO: optim post valid
|
||||
self.do_post_validation = False
|
||||
self.is_first_run = True
|
||||
self.optimizer = None
|
||||
|
@ -69,16 +68,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
self.input_tensors = [[], []]
|
||||
self.output_tensors = [[], []]
|
||||
|
||||
# y & dy buffer for schedule b
|
||||
# y & dy buffer for schedule w
|
||||
self.output_tensors_dw = [[], []]
|
||||
self.output_tensors_grad_dw = [[], []]
|
||||
|
||||
# buffer for communication
|
||||
self.send_forward_buffer = [[], []]
|
||||
self.recv_forward_buffer = [[], []]
|
||||
self.send_backward_buffer = [[], []]
|
||||
self.recv_backward_buffer = [[], []]
|
||||
self.forward_data_store = []
|
||||
|
||||
# y buffer for local send fwd
|
||||
self.local_send_forward_buffer = []
|
||||
# dy buffer for local send bwd
|
||||
self.local_send_backward_buffer = []
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
|
@ -263,7 +265,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
|
||||
Returns:
|
||||
|
@ -313,7 +314,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
|
||||
Returns:
|
||||
|
@ -371,9 +371,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
) -> Union[torch.Tensor, dict]:
|
||||
"""Forward one step of the pipeline
|
||||
Args:
|
||||
model (ModuleList or Module): Model Chunk to be run
|
||||
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
|
||||
criterion (Callable): Criterion to calculate loss.
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
input_obj (Optional[dict]): x;
|
||||
criterion (Callable): loss function;
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
|
||||
|
||||
|
@ -410,16 +411,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
) -> Optional[dict]:
|
||||
"""Backward one step of the pipeline
|
||||
"""Backward dx step of the pipeline; we calculate "dx = w*dy" here;
|
||||
|
||||
Args:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
optimizer (OptimizerWrapper): Optimizer to update the model
|
||||
input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
|
||||
output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||
output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
|
||||
input_obj (Optional[dict]): x.
|
||||
output_obj (Union[dict, torch.Tensor]): y.
|
||||
output_obj_grad (dict): dy.
|
||||
|
||||
Returns:
|
||||
Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
|
||||
Optional[dict]: dx.
|
||||
"""
|
||||
# calculate bwd b step ; only dx = w*dy;
|
||||
|
||||
|
@ -451,10 +454,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
# input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
):
|
||||
"""Backward dw step of the pipeline; we calculate "dw = x*dy" here;
|
||||
|
||||
Args:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
optimizer (OptimizerWrapper): Optimizer to update the model
|
||||
output_obj (Union[dict, torch.Tensor]): y.
|
||||
output_obj_grad (dict): dy.
|
||||
|
||||
Returns:
|
||||
Nothing need to return; we only calculate dw then update w;
|
||||
"""
|
||||
# calculate bwd w step ; only dw = x*dy;
|
||||
if model_chunk_id == 0:
|
||||
torch.autograd.backward(
|
||||
|
@ -481,6 +495,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
):
|
||||
"""A complete forward schedule; Include recv fwd --> cal fwd --> send fwd;
|
||||
|
||||
Args:
|
||||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
input_obj (Optional[dict]): x;
|
||||
criterion (Callable): loss function;
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
|
||||
|
||||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
# Step1: recv fwd
|
||||
if model_chunk_id == 0:
|
||||
# is first stage; get input from func param
|
||||
|
@ -541,6 +569,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# output_obj: Union[dict, torch.Tensor],
|
||||
# output_obj_grad: Optional[dict],
|
||||
):
|
||||
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
|
||||
|
||||
Args:
|
||||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
|
||||
# Step1: recv bwd
|
||||
if model_chunk_id == 0:
|
||||
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
|
||||
|
@ -606,6 +644,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
):
|
||||
"""A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
|
||||
|
||||
Args:
|
||||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
|
||||
# get y & dy from buffer
|
||||
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
||||
|
@ -629,7 +676,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
):
|
||||
it = self.it
|
||||
"""
|
||||
Runs Zerobubble schedule, with communication between pipeline stages.
|
||||
"""
|
||||
it = 0
|
||||
# while we still have schedules_node in self.schedules
|
||||
# print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
|
||||
while it < len(self.schedules):
|
||||
|
|
Loading…
Reference in New Issue