mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
1.5 KiB
36 lines
1.5 KiB
1 year ago
|
from typing import Any, Callable, Iterable
|
||
|
|
||
|
from torch import Tensor
|
||
|
from torch.nn import Module
|
||
|
|
||
|
from colossalai.interface import OptimizerWrapper
|
||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||
|
|
||
|
|
||
|
class PipelineSchedule:
    """Base class for pipeline-parallel training schedules.

    Stores the ``PipelineStageManager`` given at construction. The scheduling
    logic itself lives in :meth:`forward_backward_step`, which this base class
    does not implement (it raises ``NotImplementedError``).
    """

    def __init__(self, stage_manager: PipelineStageManager) -> None:
        """Store the stage manager used by this schedule.

        Args:
            stage_manager (PipelineStageManager): Stage manager for the pipeline;
                kept as ``self.stage_manager``.
        """
        self.stage_manager = stage_manager

    def forward_backward_step(self,
                              model: Module,
                              optimizer: OptimizerWrapper,
                              data_iter: Iterable,
                              criterion: Callable[[Any, Any], Tensor],
                              return_loss: bool = False,
                              return_outputs: bool = False) -> dict:
        """Forward and backward step for pipeline training.

        Args:
            model (Module): Model to be trained.
            optimizer (OptimizerWrapper): Optimizer to be used.
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It takes
                two arguments, the model outputs and the inputs, and returns the
                loss tensor.
            return_loss (bool, optional): Whether to return the loss. Defaults to False.
            return_outputs (bool, optional): Whether to return the model outputs.
                Defaults to False.

        Returns:
            dict: A dict with keys: 'loss' and 'outputs'.

        Raises:
            NotImplementedError: Always, in this base class; a concrete schedule
                must override this method.
        """
        raise NotImplementedError
|