InternLM/internlm/core/scheduler/base_scheduler.py

115 lines
4.7 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable
import torch
from internlm.core.engine import Engine
class BaseScheduler(ABC):
    """A basic helper class to control the process of training or evaluation.

    It mainly consists of forward_backward_step for gradient backward and
    optimizer_step for parameters update.
    For the convenience to enable FP16, we aggregate all codes that contain the
    control of FP16 in class schedule.

    Args:
        data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
            them into data and label.
    """

    def __init__(self, data_process_func: Callable = None):
        self.data_process_func = data_process_func

    @abstractmethod
    def pre_processing(self, engine: Engine):
        """To perform actions before running the schedule.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
        """

    def _load_micro_batch(self, data, label, offset, micro_bsz):
        """Slice one micro batch out of a full batch.

        Args:
            data (dict): Mapping from name to a sliceable value; every value is sliced with the same range.
            label (torch.Tensor): Labels of the full batch.
            offset (int): Start index of the micro batch.
            micro_bsz (int): Number of entries taken starting at ``offset``.

        Returns:
            tuple: ``(micro_batch_data, micro_batch_label)`` where ``micro_batch_data`` is a new dict holding
            ``v[offset : offset + micro_bsz]`` for every value and ``micro_batch_label`` is the same slice of ``label``.
        """
        assert isinstance(data, dict) and isinstance(label, torch.Tensor)
        micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
        micro_batch_label = label[offset : offset + micro_bsz]
        return micro_batch_data, micro_batch_label

    @abstractmethod
    def forward_backward_step(
        self,
        engine: Engine,
        data_iter: Iterable,
        forward_only: bool,
        return_loss: bool = True,
        return_output_label: bool = True,
    ):
        """The process function over a batch of dataset for training or evaluation.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
            forward_only (bool): If True, the process won't include backward.
            return_loss (bool, optional): If False, the loss won't be returned.
            return_output_label (bool, optional): If False, the output and label won't be returned.
        """

    @staticmethod
    def _call_engine(engine: Engine, inputs: Any):
        """Calls the engine with the given inputs.

        A tensor is passed as the single positional argument, a list/tuple is unpacked
        into positional arguments and a dict is unpacked into keyword arguments.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.

        Returns:
            Whatever ``engine(...)`` returns.

        Raises:
            TypeError: If ``inputs`` is none of the supported types.
        """
        if isinstance(inputs, torch.Tensor):
            return engine(inputs)
        elif isinstance(inputs, (list, tuple)):
            return engine(*inputs)
        elif isinstance(inputs, dict):
            return engine(**inputs)
        else:
            raise TypeError(
                f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
            )

    @staticmethod
    def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
        """Calls the engine's criterion with the given outputs and labels.

        Sequences are unpacked into positional arguments and dicts into keyword arguments,
        with the outputs always placed before the labels.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
            labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.

        Returns:
            Whatever ``engine.criterion(...)`` returns.

        Raises:
            AssertionError: If ``outputs`` is not a torch.Tensor, list, tuple, or dict.
            ValueError: If ``outputs`` is a dict but ``labels`` is a list or tuple.
            TypeError: If ``labels`` is none of the supported types.
        """
        assert isinstance(
            outputs, (torch.Tensor, list, tuple, dict)
        ), f"Expect output of model is (torch.Tensor, list, tuple, dict), got {type(outputs)}"

        # A bare tensor is treated as a one-element tuple so it can be unpacked like any sequence.
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        if isinstance(labels, torch.Tensor):
            labels = (labels,)

        if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
            return engine.criterion(*outputs, *labels)
        elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
            return engine.criterion(*outputs, **labels)
        elif isinstance(outputs, dict) and isinstance(labels, dict):
            return engine.criterion(**outputs, **labels)
        elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
            # Positional labels cannot follow keyword outputs in a call, so this combination is rejected.
            raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
        else:
            # Adjacent literals are concatenated at compile time; a backslash continuation inside
            # a single literal would leak quote characters and indentation into the message.
            raise TypeError(
                "Expected model outputs and labels to be of type torch.Tensor "
                "(which is auto-converted to tuple), list, tuple, or dict, "
                f"but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
            )