InternLM/internlm/core/no_pipeline_scheduler.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable

import torch

from internlm.core.engine import Engine
from internlm.utils.common import conditional_context


class BaseScheduler(ABC):
"""A basic helper class to control the process of training or evaluation.
It mainly composes of forward_backward_step for gradient backward and
optimizer_step for parameters update.
For the convenience to enable FP16, we aggregate all codes that contain the
control of FP16 in class schedule.
Args:
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
them into data and label.
"""
def __init__(self, data_process_func: Callable = None):
self.data_process_func = data_process_func

    @abstractmethod
    def pre_processing(self, engine: Engine):
        """To perform actions before running the schedule.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
        """
        pass

    @abstractmethod
    def forward_backward_step(
        self,
        engine: Engine,
        data_iter: Iterable,
        forward_only: bool,
        return_loss: bool = True,
        return_output_label: bool = True,
    ):
        """The process function over a batch of the dataset for training or evaluation.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            data_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
            forward_only (bool): If True, the process won't include backward.
            return_loss (bool, optional): If False, the loss won't be returned.
            return_output_label (bool, optional): If False, the output and label won't be returned.
        """
        pass

    @staticmethod
    def _call_engine(engine: Engine, inputs: Any):
        """Calls the engine with the given inputs.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
        """
        if isinstance(inputs, torch.Tensor):
            return engine(inputs)
        elif isinstance(inputs, (list, tuple)):
            return engine(*inputs)
        elif isinstance(inputs, dict):
            return engine(**inputs)
        else:
            raise TypeError(
                f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
            )
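
    # Illustrative sketch (not part of the scheduler): how the `_call_engine` dispatch above maps
    # onto an ordinary callable. The plain `nn.Linear` below is a stand-in for the real engine,
    # used only to show the three supported input shapes.
    #
    #     model = torch.nn.Linear(4, 2)
    #     x = torch.randn(3, 4)
    #     model(x)                # torch.Tensor input  -> engine(inputs)
    #     model(*(x,))            # list/tuple input    -> engine(*inputs)
    #     model(**{"input": x})   # dict input          -> engine(**inputs), keys must match forward()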

    @staticmethod
    def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
        """Calls the engine's criterion with the given outputs and labels.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
            labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
        """
        assert isinstance(
            outputs, (torch.Tensor, list, tuple, dict)
        ), f"Expected model outputs to be of type torch.Tensor, list, tuple, or dict, but got {type(outputs)}"

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        if isinstance(labels, torch.Tensor):
            labels = (labels,)

        if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
            return engine.criterion(*outputs, *labels)
        elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
            return engine.criterion(*outputs, **labels)
        elif isinstance(outputs, dict) and isinstance(labels, dict):
            return engine.criterion(**outputs, **labels)
        elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
            raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
        else:
            raise TypeError(
                f"Expected model outputs and labels to be of type torch.Tensor "
                f"(which is auto-converted to tuple), list, tuple, or dict, "
                f"but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
            )
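
    # Illustrative sketch: the tuple/tuple branch above behaves like calling a normal PyTorch loss
    # with positional arguments; `nn.CrossEntropyLoss` here is only a stand-in for `engine.criterion`.
    #
    #     criterion = torch.nn.CrossEntropyLoss()
    #     logits = torch.randn(3, 5)
    #     target = torch.randint(0, 5, (3,))
    #     criterion(*(logits,), *(target,))   # equivalent to criterion(logits, target)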


class NonPipelineScheduler(BaseScheduler):
    """A helper schedule class for running environments without pipeline parallelism.
    In each step, it loads a batch of data and feeds it to the model.
    After getting the output and calculating the loss, it will use :meth:`step`
    to update the parameters if it is in training mode.

    Args:
        data_process_func (Callable, optional): The preprocessing function which receives a batch of data
            and returns a tuple in the form of (data, label); it will be executed in load_batch.
        gradient_accumulation_size (int, optional): The number of gradient accumulation steps, 1 to disable
            gradient accumulation.

    Example:
        # this shows an example of a customized data_process_func
        def data_process_func(dataloader_output):
            item1, item2, item3 = dataloader_output
            data = (item1, item2)
            label = item3
            return data, label
    """

    def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
        # check that the non-pipeline schedule data process func only takes in one parameter,
        # which is the batch data
        if data_process_func:
            sig = inspect.signature(data_process_func)
            assert len(sig.parameters) == 1, (
                "The data_process_func only takes in one parameter for NonPipelineSchedule, "
                "which is a tuple of tensors for the current batch, "
                "i.e. data_process_func(dataloader_output)."
            )

        self._grad_accum_size = gradient_accumulation_size
        self._grad_accum_batch_size = 1  # static batch size for flash attention.
        self._grad_accum_offset = 0

        super().__init__(data_process_func)
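
    # Note (illustrative): with `_grad_accum_batch_size` fixed at 1, `forward_backward_step` expects
    # the loaded batch size to equal `_grad_accum_size` (see the assert there), so each accumulation
    # step consumes exactly one micro-batch slice of the loaded batch.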

    def pre_processing(self, engine: Engine):
        """Performs actions before running the schedule.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
        """
        pass

    def _load_accum_batch(self, data: Any, label: Any):
        """Loads a batch of data and label for gradient accumulation.

        Args:
            data (Any): The data to be loaded.
            label (Any): The label to be loaded.
        """
        _data = {
            k: v[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
            for k, v in data.items()
        }
        _label = label[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]

        self._grad_accum_offset += self._grad_accum_batch_size

        return _data, _label
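
    # Illustrative sketch: `_load_accum_batch` advances a sliding window of width
    # `_grad_accum_batch_size` (= 1) over every tensor in the batch dict and over the labels.
    # The toy shapes below are assumptions, not taken from InternLM's data pipeline.
    #
    #     data = {"input_ids": torch.arange(8).reshape(4, 2)}   # batch of 4 micro-samples
    #     label = torch.arange(4)
    #     # step 0 -> data["input_ids"][0:1], label[0:1]
    #     # step 1 -> data["input_ids"][1:2], label[1:2]
    #     # ... until _grad_accum_offset reaches _grad_accum_size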

    def _train_one_batch(
        self,
        data: Any,
        label: Any,
        engine: Engine,
        forward_only: bool = False,
        return_loss: bool = True,
        scale_loss: int = 1,
    ):
        """Trains one batch of data.

        Args:
            data (Any): The data to be trained on.
            label (Any): The label for the data.
            engine (internlm.core.Engine): InternLM engine for training and inference.
            forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
                be executed.
            return_loss (bool, optional): Loss will be returned if True.
            scale_loss (int, optional): The scale factor for the loss.
        """
        # forward
        with conditional_context(torch.no_grad(), enable=forward_only):
            output = self._call_engine(engine, data)

            if return_loss:
                loss = self._call_engine_criterion(engine, output, label)
                loss /= scale_loss

        # backward
        if not forward_only:
            engine.backward(loss)

        if not return_loss:
            loss = None

        return output, loss
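
    # Note on loss scaling (illustrative): with scale_loss = self._grad_accum_size = N, each
    # micro-batch backward contributes grad(loss_i / N), so the gradient accumulated over the
    # N steps equals the gradient of the average micro-batch loss, mean_i(loss_i).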

    def forward_backward_step(
        self,
        engine: Engine,
        data_iter: Iterable,
        forward_only: bool = False,
        return_loss: bool = True,
        return_output_label: bool = True,
    ):
        """The process function that loads a batch of data and feeds it to the model.
        The returned labels and loss will be None if :attr:`return_loss` is False.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            data_iter (Iterable): Dataloader in the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                If True, the model is run for the forward pass, else back propagation will be executed.
            return_loss (bool, optional): Loss will be returned if True.
            return_output_label (bool, optional): Output and label will be returned if True.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label could be None.
        """
        assert (
            forward_only or return_loss
        ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

        batch_data, batch_size = engine.load_batch(data_iter)

        assert (
            batch_size == self._grad_accum_size
        ), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"

        if self.data_process_func:
            data, label = self.data_process_func(batch_data)
        else:
            # if no batch data process func is given,
            # then we regard the batch data as a simple tuple of (data, label)
            data, label = batch_data

        loss = 0 if return_loss else None
        outputs = []
        labels = []

        # reset the accumulation micro-batch offset
        self._grad_accum_offset = 0

        for _current_accum_step in range(self._grad_accum_size):
            if _current_accum_step == self._grad_accum_size - 1:
                engine.optimizer.skip_grad_reduce = False
            else:
                engine.optimizer.skip_grad_reduce = True

            _data, _label = self._load_accum_batch(data, label)

            _output, _loss = self._train_one_batch(
                _data, _label, engine, forward_only, return_loss, self._grad_accum_size
            )

            if return_loss:
                loss += _loss
            if return_output_label:
                outputs.append(_output)
                labels.append(_label)

        if not return_output_label:
            outputs, labels = None, None

        return outputs, labels, loss
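

# Hedged usage sketch (comments only, so nothing runs on import): how a training loop might drive
# this scheduler. The construction of `engine` (internlm.core.engine.Engine) and `train_iter`
# (iter(dataloader)), and the exact Engine method names (train/zero_grad/step), are assumptions
# here, not defined in this file.
#
#     scheduler = NonPipelineScheduler(gradient_accumulation_size=4)
#     scheduler.pre_processing(engine)
#     engine.train()
#     engine.zero_grad()
#     # runs 4 accumulation micro-steps internally, calling engine.backward() on each scaled loss
#     outputs, labels, loss = scheduler.forward_backward_step(
#         engine, train_iter, forward_only=False, return_loss=True
#     )
#     engine.step()  # optimizer update after the accumulated backward passes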