mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

import json
from typing import Iterable, Optional

from internlm.core.engine import Engine
from internlm.core.scheduler import (
    BaseScheduler,
    InterleavedPipelineScheduler,
    NonPipelineScheduler,
    PipelineScheduler,
)


class TrainState:
    """
    The TrainState class is used to record the current state of training.

    Args:
        config (Config): The InternLM config used for training.
        batch_sampler (torch.utils.data.Sampler): The batch sampler of the training dataloader.
    """

    def __init__(self, config, batch_sampler) -> None:
        """
        Args:
            config (Config): The InternLM config.
            batch_sampler (torch.utils.data.Sampler): Because dataloader loading is
                asynchronous and prefetched, the batch_sampler state maintained inside
                the dataloader runs ahead of the actual training progress, so we keep
                a copy of the batch_sampler as the anchor point for checkpoint reload.
        """
        # The number of batches produced by the data iterator
        self.batch_count: int = 0
        # Number of samples consumed in the current epoch
        self.num_consumed_samples_in_epoch: int = 0
        # Total number of tokens consumed
        self.num_consumed_tokens: int = 0
        # Number of batches skipped due to inf or nan values
        self.inf_nan_skip_batches: int = 0
        # Number of parameter updates; skipped batches and inf/nan batches are not counted
        self.step_count: int = 0

        # Total step count
        self.total_steps: int = config.data.total_steps

        # Tensorboard folder to resume from; loaded from a checkpoint or set manually
        self.resume_tb_folder = config.resume_tb_folder

        self.tensorboard_folder = config.tensorboard_folder

        # Learning rate
        self.lr = config.adam.lr

        # Sampler state
        if batch_sampler:
            self.init_batch_sampler(batch_sampler)

    def init_batch_sampler(self, batch_sampler):
        """
        Args:
            batch_sampler (torch.utils.data.Sampler): The batch sampler to copy.
        """
        # Make a copy of batch_sampler as the anchor point for checkpoint reload
        self.batch_sampler = batch_sampler.copy()
        # Iterator over the copied batch sampler
        self.batch_sampler_iter = iter(self.batch_sampler)

    def __str__(self) -> str:
        """Returns a string representation of the training state in JSON format."""
        info = {
            "batch_count": self.batch_count,
            "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
            "num_consumed_tokens": self.num_consumed_tokens,
            "inf_nan_skip_batches": self.inf_nan_skip_batches,
            "step_count": self.step_count,
        }

        return json.dumps(info, indent=4, sort_keys=True)

    def load_state_dict(self, other_stuffs):
        """
        Resumes training from a checkpoint.

        Args:
            other_stuffs (dict): Other information needed to resume training.
        """
        self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
        self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
        self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]

        # The checkpoint is saved after 'step_count' has been updated, so there is no
        # need to increment 'step_count' here. However, 'batch_count' is updated before
        # the checkpoint is stored, so it must be incremented by 1 on resume so that
        # training continues from the next, not-yet-consumed batch.
        self.batch_count = other_stuffs["batch_count"] + 1
        self.step_count = other_stuffs.get("step_count", self.batch_count)

        # Resume tensorboard from the old tensorboard_folder
        self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)

    def state_dict(self):
        return {
            "batch_count": self.batch_count,
            "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
            "num_consumed_tokens": self.num_consumed_tokens,
            "inf_nan_skip_batches": self.inf_nan_skip_batches,
            "step_count": self.step_count,
            "tensorboard_folder": self.tensorboard_folder,
        }
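

# A minimal usage sketch (illustrative, not part of the original module): save
# TrainState alongside a checkpoint, then restore it on resume. The `config` and
# `batch_sampler` arguments are assumed to satisfy what TrainState expects above
# (config.data.total_steps, config.adam.lr, a batch_sampler with .copy(), etc.).
def _example_train_state_roundtrip(config, batch_sampler):
    state = TrainState(config, batch_sampler)
    state.batch_count += 1  # one batch consumed, then a checkpoint is written
    saved = state.state_dict()

    # On resume, load_state_dict restores the counters and bumps batch_count by 1
    # so training continues from the next, not-yet-consumed batch.
    resumed = TrainState(config, batch_sampler)
    resumed.load_state_dict(saved)
    assert resumed.batch_count == saved["batch_count"] + 1
    return resumed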


class Trainer:
    """A class that makes it easy for users to run training and evaluation without
    writing their own scripts.

    Args:
        engine (:class:`Engine`): Engine responsible for the process function.
        schedule (:class:`BaseScheduler`, optional): Runtime schedule. Defaults to None.
    """

    def __init__(
        self,
        engine: Engine,
        schedule: Optional[BaseScheduler] = None,
    ):
        """Initializes the Trainer class.

        Args:
            engine (Engine): The engine responsible for the process function.
            schedule (Optional[BaseScheduler], optional): The runtime schedule. Defaults to None.
        """
        self._engine = engine

        # Build the schedule; fall back to a non-pipeline schedule when none is given
        if schedule is None:
            self._schedule = NonPipelineScheduler()
        else:
            assert isinstance(
                schedule, BaseScheduler
            ), f"expected schedule to be of type BaseScheduler, but got {type(schedule)}"
            self._schedule = schedule

        self._schedule.pre_processing(self._engine)

    @property
    def engine(self):
        """Returns the engine responsible for managing the training and evaluation process."""
        return self._engine

    @property
    def schedule(self):
        """Returns the runtime scheduler."""
        return self._schedule

    @property
    def uses_pipeline(self):
        """Returns whether pipeline parallelism is used."""
        return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))

    def train(self):
        """Sets the model to training mode."""
        self._engine.train()

    def eval(self):
        """Sets the model to evaluation mode."""
        self._engine.eval()

    def zero_grad(self):
        """Sets the gradients of all parameters in the model to zero."""
        self._engine.zero_grad()

    def step(self):
        """Executes the parameter update step."""
        return self._engine.step()

    def execute_schedule(self, data_iter: Iterable, **kwargs):
        """Runs the forward, loss computation, and backward steps for the model.
        Returns a tuple of (output, label, loss).

        Args:
            data_iter (Iterable): The data iterator.
            **kwargs: Additional keyword arguments forwarded to the schedule's
                forward_backward_step.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
        """
        output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
        return output, label, loss
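

# A minimal usage sketch (illustrative, not part of the original module): build a
# Trainer and run one training step. `engine` and `data_iter` are assumed to be
# constructed elsewhere, and the keyword arguments passed through execute_schedule
# (`forward_only`, `return_loss`) are assumptions about the scheduler's
# forward_backward_step signature.
def _example_training_step(engine: Engine, data_iter: Iterable):
    # With no schedule given, Trainer falls back to NonPipelineScheduler
    trainer = Trainer(engine=engine)
    assert not trainer.uses_pipeline

    trainer.train()      # put the model in training mode
    trainer.zero_grad()  # clear gradients from the previous step
    _, _, loss = trainer.execute_schedule(data_iter, forward_only=False, return_loss=True)
    trainer.step()       # apply the parameter update
    return loss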