mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.7 KiB
100 lines
3.7 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import inspect
|
|
from typing import Callable, Iterable
|
|
|
|
import torch
|
|
|
|
from colossalai.utils import conditional_context
|
|
|
|
from ._base_schedule import BaseSchedule
|
|
|
|
|
|
class NonPipelineSchedule(BaseSchedule):
    """A helper schedule class for a running environment without pipeline parallelism.

    For every call to :meth:`forward_backward_step` it loads one batch from the
    data iterator, feeds it to the model through the engine, optionally computes
    the loss and, in training mode (``forward_only=False``), back-propagates the
    loss through ``engine.backward``. The methods defined here do not call an
    optimizer step themselves.

    Args:
        data_process_func (Callable, optional): The preprocessing function which receives a batch of data
            and returns a tuple in the form of (data, label). It must accept exactly one
            parameter and is applied in :meth:`forward_backward_step` to the batch returned
            by ``load_batch``. When it is not given, the batch itself is unpacked as
            ``(data, label)``.

    Raises:
        AssertionError: If ``data_process_func`` is given and does not take exactly one parameter.

    Example:
        # this shows an example of customized data_process_func
        def data_process_func(dataloader_output):
            item1, item2, item3 = dataloader_output
            data = (item1, item2)
            label = item3
            return data, label
    """

    def __init__(self, data_process_func: Callable = None) -> None:
        # A non-pipeline schedule hands the whole dataloader output to the hook as a
        # single argument, so a callable with any other arity is rejected up front.
        if data_process_func:
            sig = inspect.signature(data_process_func)
            # Raised explicitly instead of via ``assert`` so that the check is not
            # stripped when Python runs with ``-O``.
            if len(sig.parameters) != 1:
                raise AssertionError(
                    "The data_process_func only takes in one parameter for NonPipelineSchedule, "
                    "which is a tuple of tensors for the current batch, "
                    "i.e. data_process_func(dataloader_output)."
                )

        super().__init__(data_process_func)

    def forward_backward_step(
        self,
        engine,
        data_iter: Iterable,
        forward_only: bool = False,
        return_loss: bool = True,
        return_output_label: bool = True,
    ) -> tuple:
        """Load one batch, run the forward pass and, unless ``forward_only``, the backward pass.

        The returned loss is None if :attr:`return_loss` is False; the returned output
        and label are None if :attr:`return_output_label` is False.

        Args:
            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                If True, the model is run for the forward pass only, else back propagation will be executed.
            return_loss (bool, optional): Loss will be computed and returned if True.
                Must be True whenever ``forward_only`` is False, because the backward pass needs the loss.
            return_output_label (bool, optional): Output and label will be returned if True.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); each element may be None
            as described above.

        Raises:
            AssertionError: If ``forward_only`` and ``return_loss`` are both False.
        """
        # Explicit raise (same exception type and message as the former ``assert``) so the
        # guard still protects ``engine.backward`` when assertions are disabled with ``-O``.
        if not (forward_only or return_loss):
            raise AssertionError(
                "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
            )

        batch_data = self.load_batch(data_iter)
        if self.data_process_func:
            data, label = self.data_process_func(batch_data)
        else:
            # if no batch data process func is given,
            # then we regard the batch data as a simple tuple of (data, label)
            data, label = batch_data

        # Stays None when the caller did not ask for the loss.
        loss = None

        # forward
        # ``torch.no_grad()`` is handed to ``conditional_context`` with ``enable=forward_only``;
        # presumably this disables autograd only for forward-only runs (confirm in the helper).
        with conditional_context(torch.no_grad(), enable=forward_only):
            output = self._call_engine(engine, data)
            if return_loss:
                loss = self._call_engine_criterion(engine, output, label)

        if not forward_only:
            engine.backward(loss)

        if return_output_label:
            return output, label, loss
        return None, None, loss
|