You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/applications/ColossalChat/coati/trainer/base.py

215 lines
6.8 KiB

"""
Base trainers for online and offline training
SLTrainer: supervised learning trainer
pretrain, sft, dpo, reward model training
OLTrainer: online learning trainer
rlhf-ppo
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional

import torch.nn as nn
import tqdm
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer

from colossalai.booster import Booster

from .utils import is_rank_0
class SLTrainer(ABC):
    """
    Base class for supervised learning trainers.

    Subclasses implement ``_before_fit``, ``_train`` and ``_eval``; ``fit``
    drives them once per epoch.

    Args:
        booster (Booster): the booster to use for training
        max_epochs (int): exclusive upper bound of the epoch index range
        model (nn.Module): the model to train
        optimizer (Optimizer): the optimizer to use for training
        start_epoch (int, defaults to 0): the epoch index to start from
    """

    def __init__(
        self,
        booster: Booster,
        max_epochs: int,
        model: nn.Module,
        optimizer: Optimizer,
        start_epoch: int = 0,
    ) -> None:
        super().__init__()
        self.booster = booster
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer
        self.start_epoch = start_epoch

    @abstractmethod
    def _train(self, epoch):
        """
        Implement this method to run the training part of one epoch.

        Args:
            epoch (int): index of the current epoch
        """
        raise NotImplementedError()

    @abstractmethod
    def _eval(self, epoch):
        """
        Implement this method to run the evaluation part of one epoch.

        Args:
            epoch (int): index of the current epoch
        """
        raise NotImplementedError()

    @abstractmethod
    def _before_fit(self, *args, **kwargs):
        """
        Implement this method to prepare for training.

        It receives every positional and keyword argument passed to ``fit``,
        so the base signature accepts arbitrary arguments; subclasses may
        narrow it to the parameters they need.
        """
        raise NotImplementedError()

    def fit(self, *args, **kwargs):
        """
        The main training loop of supervised learning trainers.

        Calls ``_before_fit`` with the given arguments, then runs ``_train``
        followed by ``_eval`` for every epoch index in
        ``range(start_epoch, max_epochs)``. Nothing is trained when
        ``start_epoch >= max_epochs``.

        Args:
            *args: forwarded unchanged to ``_before_fit``
            **kwargs: forwarded unchanged to ``_before_fit``
        """
        self._before_fit(*args, **kwargs)
        # The progress bar is disabled whenever is_rank_0() returns False.
        for epoch in tqdm.trange(self.start_epoch, self.max_epochs, desc="Epochs", disable=not is_rank_0()):
            self._train(epoch)
            self._eval(epoch)
class OLTrainer(ABC):
    """
    Base class for online learning trainers, e.g. PPO.

    Subclasses implement ``_make_experience``, ``_learn``,
    ``_setup_update_phrase_dataload``, ``_save_checkpoint`` and override
    ``_before_fit``; ``fit`` drives the collect/update loop and fires the
    callback hooks.

    Args:
        actor_booster (Booster): booster stored as ``self.actor_booster``
            (presumably wraps the actor model; not used by this base class)
        critic_booster (Booster): booster stored as ``self.critic_booster``
            (presumably wraps the critic model; not used by this base class)
        data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
        sample_buffer (bool): whether to sample from buffer; when False,
            ``_setup_update_phrase_dataload`` is called before the update steps
        dataloader_pin_memory (bool): whether to pin memory for data loader
        callbacks (List[Callback], defaults to None): the callbacks to call
            during training process; ``None`` means no callbacks. Each
            callback must provide the ``on_*`` hook methods invoked below.
    """

    def __init__(
        self,
        actor_booster: Booster,
        critic_booster: Booster,
        data_buffer: NaiveExperienceBuffer,
        sample_buffer: bool,
        dataloader_pin_memory: bool,
        callbacks: Optional[List[Callable]] = None,
    ) -> None:
        super().__init__()
        self.actor_booster = actor_booster
        self.critic_booster = critic_booster
        self.data_buffer = data_buffer
        self.sample_buffer = sample_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        # A fresh list per instance: a literal [] default would be shared by
        # every trainer created without callbacks. A list supplied by the
        # caller is stored as-is (not copied).
        self.callbacks = callbacks if callbacks is not None else []

    @contextmanager
    def _fit_ctx(self) -> Iterator[None]:
        """
        Context manager wrapping a whole ``fit`` run.

        Calls ``on_fit_start`` on every callback on entry and ``on_fit_end``
        on every callback on exit, including when the body raises.
        """
        for callback in self.callbacks:
            callback.on_fit_start()
        try:
            yield
        finally:
            for callback in self.callbacks:
                callback.on_fit_end()

    @contextmanager
    def _episode_ctx(self, episode: int) -> Iterator[None]:
        """
        Context manager wrapping a single episode.

        Calls ``on_episode_start(episode)`` on every callback on entry and
        ``on_episode_end(episode)`` on exit, including when the body raises.

        Args:
            episode (int): index of the current episode
        """
        for callback in self.callbacks:
            callback.on_episode_start(episode)
        try:
            yield
        finally:
            for callback in self.callbacks:
                callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        """Notify every callback that experience making is about to start."""
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        """Notify every callback that ``experience`` has been made."""
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        """Notify every callback that learn epoch ``epoch`` is starting."""
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        """Notify every callback that learn epoch ``epoch`` has ended."""
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        """
        Notify every callback that a learn batch is starting.

        Not called by this base class; available for subclasses to call.
        """
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, experience: Experience) -> None:
        """
        Notify every callback that a learn batch on ``experience`` has ended.

        Not called by this base class; available for subclasses to call.
        """
        for callback in self.callbacks:
            callback.on_learn_batch_end(experience)

    @abstractmethod
    def _make_experience(self, collect_step: int):
        """
        Implement this method to make experience.
        """
        raise NotImplementedError()

    @abstractmethod
    def _learn(self, update_step: int):
        """
        Implement this method to learn from experience, either
        sample from buffer or transform buffer into dataloader.
        """
        raise NotImplementedError()

    @abstractmethod
    def _setup_update_phrase_dataload(self):
        """
        Implement this method to setup dataloader for update phase.
        """
        raise NotImplementedError()

    @abstractmethod
    def _save_checkpoint(self, episode: int = 0):
        """
        Implement this method to save checkpoint.
        """
        raise NotImplementedError()

    def _collect_phase(self, collect_step: int):
        """
        Run one collect step: make an experience, fire the make-experience
        hooks around it, and append the result to ``self.data_buffer``.

        Args:
            collect_step (int): index of the collect step within the episode
        """
        self._on_make_experience_start()
        experience = self._make_experience(collect_step)
        self._on_make_experience_end(experience)
        self.data_buffer.append(experience)

    def _update_phase(self, update_step: int):
        """
        Run one update step: call ``_learn`` between the learn-epoch hooks.

        Note that the "learn epoch" hooks receive the update step index.

        Args:
            update_step (int): index of the update step within the episode
        """
        self._on_learn_epoch_start(update_step)
        self._learn(update_step)
        self._on_learn_epoch_end(update_step)

    def _before_fit(self, *args, **kwargs):
        """
        Override this method to prepare for training.

        Receives the extra positional and keyword arguments passed to ``fit``.
        """
        raise NotImplementedError()

    def fit(
        self,
        num_episodes: int,
        num_collect_steps: int,
        num_update_steps: int,
        *args,
        **kwargs,
    ):
        """
        The main training loop of on-policy rl trainers.

        Each episode runs ``num_collect_steps`` collect steps, then
        ``num_update_steps`` update steps, then clears the data buffer and
        optionally saves a checkpoint.

        Checkpointing is controlled by a ``save_interval`` attribute that this
        base class does not set; subclasses are expected to assign it. When it
        is missing, ``None`` or not positive, no checkpoint is saved.

        Args:
            num_episodes (int): the number of episodes to train
            num_collect_steps (int): the number of collect steps per episode
            num_update_steps (int): the number of update steps per episode
            *args: forwarded unchanged to ``_before_fit``
            **kwargs: forwarded unchanged to ``_before_fit``
        """
        self._before_fit(*args, **kwargs)
        # Progress bars are disabled whenever is_rank_0() returns False.
        show_progress = is_rank_0()
        with self._fit_ctx():
            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not show_progress):
                with self._episode_ctx(episode):
                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not show_progress):
                        self._collect_phase(collect_step)
                    if not self.sample_buffer:
                        self._setup_update_phrase_dataload()
                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not show_progress):
                        self._update_phase(update_step)
                    # NOTE: this is for on-policy algorithms
                    self.data_buffer.clear()
                    # Read per episode so a value assigned by a subclass (even
                    # during training) is honoured; fall back to 0 (disabled)
                    # instead of raising AttributeError when it was never set.
                    save_interval = getattr(self, "save_interval", 0) or 0
                    if save_interval > 0 and (episode + 1) % save_interval == 0:
                        self._save_checkpoint(episode + 1)