@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.cuda
@@ -696,6 +696,54 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_grad=output_obj_grad,
         )
 
+    def run_forward_only(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        assert self.forward_only
+
+        # prepare batch
+        self.load_batch(data_iter)
+
+        # prepare accum loss & output
+        accum_loss = None
+
+        # reset accum loss at fwd end;
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
+        it = 0
+        # while we still have schedules_node in self.schedules
+        while it < len(self.schedules):
+            scheduled_node = self.schedules[it]
+
+            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+                # communication
+                if scheduled_node.type == "RECV_FORWARD":
+                    self.recv_forward(scheduled_node.chunk)
+                elif scheduled_node.type == "SEND_FORWARD":
+                    self.send_forward(scheduled_node.chunk)
+            if scheduled_node.type == "F":
+                self.schedule_f(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    criterion=criterion,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
+                )
+            it += 1
+        # return loss & output
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
+
     def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
@@ -704,7 +752,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ):
+    ) -> Dict:
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
@@ -770,3 +818,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if outputs is not None:
             outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model_chunk (ModuleList or Module): Model chunk to be trained. The original interleaved schedule uses a module list, whereas shardformer uses the entire model plus a layer specification.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+        Returns:
+            dict: A dict with keys 'loss' and 'outputs'.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(
+                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
+            )
+
+        return result
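
For reference, a minimal usage sketch of the new forward_backward_step entry point introduced above. This is not part of the patch: the helper names eval_step and train_step are illustrative, and the scheduler, model chunk, data iterator, criterion, and wrapped optimizer are assumed to be constructed elsewhere. Since forward_only is derived from torch.is_grad_enabled(), the grad mode decides which path runs, and an optimizer is only required when gradients are enabled.

    import torch

    def eval_step(scheduler, model_chunk, data_iter, criterion):
        # Grad disabled -> forward_only is True and run_forward_only is used;
        # the optimizer may be omitted in this mode.
        with torch.no_grad():
            return scheduler.forward_backward_step(
                model_chunk=model_chunk,
                data_iter=data_iter,
                criterion=criterion,
                optimizer=None,
                return_loss=True,
                return_outputs=True,
            )

    def train_step(scheduler, model_chunk, data_iter, criterion, optimizer):
        # Grad enabled -> run_forward_backward is used; the wrapped optimizer
        # must be passed, otherwise the assertion in forward_backward_step fires.
        out = scheduler.forward_backward_step(
            model_chunk=model_chunk,
            data_iter=data_iter,
            criterion=criterion,
            optimizer=optimizer,
            return_loss=True,
            return_outputs=False,
        )
        return out["loss"]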