You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/pipeline/schedule/base.py

37 lines
1.4 KiB

from typing import Any, Callable, Iterable, Optional
from torch import Tensor
from torch.nn import Module
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.stage_manager import PipelineStageManager
class PipelineSchedule:
    """Base class for pipeline training schedules.

    This base class only stores the stage manager it is constructed with;
    :meth:`forward_backward_step` raises ``NotImplementedError`` here and is
    expected to be overridden by concrete schedules.

    Args:
        stage_manager (PipelineStageManager): Stage manager for the pipeline,
            kept on the instance as ``self.stage_manager``.
    """

    def __init__(self, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager

    def forward_backward_step(
        self,
        model: Module,
        data_iter: Iterable,
        criterion: Callable[[Any, Any], Tensor],
        optimizer: Optional[OptimizerWrapper] = None,
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> dict:
        """Forward and backward step for pipeline training.

        Args:
            model (Module): Model to be trained.
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should
                take two arguments, model outputs and inputs, and return a loss tensor.
            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None
                when only forward is executed. Defaults to None.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs.
                Defaults to False.

        Returns:
            dict: A dict with keys: 'loss' and 'outputs'.

        Raises:
            NotImplementedError: Always, in this base class; concrete schedules
                must override this method.
        """
        raise NotImplementedError