You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/applications/Chat/coati/trainer/base.py

191 lines
6.7 KiB

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Iterator, List, Optional

import torch.nn as nn
import tqdm
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from .callbacks import Callback
from .strategies import Strategy
from .utils import CycledDataLoader, is_rank_0
class SLTrainer(ABC):
    """
    Base class for supervised learning trainers.

    Subclasses implement ``_train`` and ``_eval``; ``fit`` calls both, in that
    order, once per epoch for ``max_epochs`` epochs.

    Args:
        strategy (Strategy): the strategy to use for training
        max_epochs (int): the number of epochs of training process
        model (nn.Module): the model to train
        optimizer (Optimizer): the optimizer to use for training
    """

    def __init__(
        self,
        strategy: Strategy,
        max_epochs: int,
        model: nn.Module,
        optimizer: Optimizer,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer

    @abstractmethod
    def _train(self, epoch):
        """Run the training part of epoch ``epoch``; implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def _eval(self, epoch):
        """Run the evaluation part of epoch ``epoch``; implemented by subclasses."""
        raise NotImplementedError()

    def _before_fit(self):
        """
        Hook invoked by ``fit`` before the epoch loop starts.

        ``fit`` forwards all of its positional and keyword arguments to this
        hook, so an override defines which arguments ``fit`` accepts. The base
        version takes none and only sets ``no_epoch_bar`` to ``False``, which
        leaves the epoch progress bar enabled.
        """
        self.no_epoch_bar = False

    def fit(self, *args, **kwargs):
        """
        Run the training loop: ``_before_fit`` once, then ``_train`` and
        ``_eval`` for every epoch in ``range(self.max_epochs)``.

        Args:
            *args: forwarded unchanged to ``_before_fit``
            **kwargs: forwarded unchanged to ``_before_fit``
        """
        self._before_fit(*args, **kwargs)
        # An overriding ``_before_fit`` is not obliged to set ``no_epoch_bar``;
        # fall back to showing the bar instead of failing with AttributeError.
        # The bar is always disabled when ``is_rank_0()`` is false.
        hide_epoch_bar = not is_rank_0() or getattr(self, "no_epoch_bar", False)
        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=hide_epoch_bar):
            self._train(epoch)
            self._eval(epoch)
class OnPolicyTrainer(ABC):
    """
    Base class for on-policy rl trainers, e.g. PPO.

    Subclasses implement ``_make_experience`` and ``_learn``; ``fit`` drives
    the collect/update loop and notifies the registered callbacks.

    Args:
        strategy (Strategy): the strategy to use for training
        data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
        sample_buffer (bool): whether to sample from buffer; when False, ``fit``
            builds a dataloader over the buffer via ``strategy.setup_dataloader``
        dataloader_pin_memory (bool): whether to pin memory for data loader
        callbacks (List[Callback], optional): the callbacks to call during
            training process. Defaults to no callbacks.
    """

    def __init__(
        self,
        strategy: Strategy,
        data_buffer: NaiveExperienceBuffer,
        sample_buffer: bool,
        dataloader_pin_memory: bool,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.data_buffer = data_buffer
        self.sample_buffer = sample_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        # A fresh list per instance: a shared ``[]`` default would leak
        # callbacks appended on one trainer into every other trainer.
        self.callbacks = [] if callbacks is None else callbacks

    @contextmanager
    def _fit_ctx(self) -> Iterator[None]:
        """Context manager wrapping a whole ``fit`` run.

        Calls ``on_fit_start`` on every callback on entry and ``on_fit_end`` on
        exit; the exit calls sit in ``finally`` so they also run when the body
        raises.
        """
        for callback in self.callbacks:
            callback.on_fit_start()
        try:
            yield
        finally:
            for callback in self.callbacks:
                callback.on_fit_end()

    @contextmanager
    def _episode_ctx(self, episode: int) -> Iterator[None]:
        """Context manager wrapping one episode.

        Calls ``on_episode_start(episode)`` on entry and
        ``on_episode_end(episode)`` on exit (in ``finally``, so also on error).
        """
        for callback in self.callbacks:
            callback.on_episode_start(episode)
        try:
            yield
        finally:
            for callback in self.callbacks:
                callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        """Notify every callback that experience making is about to start."""
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        """Notify every callback that ``experience`` has just been made."""
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        """Notify every callback that learn epoch ``epoch`` is starting."""
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        """Notify every callback that learn epoch ``epoch`` has finished."""
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        """Notify every callback that a learn batch is starting."""
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        """Notify every callback that a learn batch finished, passing its
        ``metrics`` and the ``experience`` it used."""
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)

    @abstractmethod
    def _make_experience(self, collect_step: int):
        """
        Implement this method to make experience.

        The returned value is passed to the ``on_make_experience_end``
        callbacks and appended to ``self.data_buffer``.
        """
        raise NotImplementedError()

    @abstractmethod
    def _learn(self, update_step: int):
        """
        Implement this method to learn from experience, either
        sample from buffer or transform buffer into dataloader.
        """
        raise NotImplementedError()

    def _collect_phase(self, collect_step: int) -> None:
        """Make one experience, fire the surrounding callbacks and store the
        experience in ``self.data_buffer``."""
        self._on_make_experience_start()
        experience = self._make_experience(collect_step)
        self._on_make_experience_end(experience)
        self.data_buffer.append(experience)

    def _update_phase(self, update_step: int) -> None:
        """Run one ``_learn`` step wrapped by the learn-epoch callbacks."""
        self._on_learn_epoch_start(update_step)
        self._learn(update_step)
        self._on_learn_epoch_end(update_step)

    def fit(
        self,
        prompt_dataloader: DataLoader,
        pretrain_dataloader: DataLoader,
        num_episodes: int,
        num_collect_steps: int,
        num_update_steps: int,
    ) -> None:
        """
        The main training loop of on-policy rl trainers.

        Each episode runs ``num_collect_steps`` collect phases, then
        ``num_update_steps`` update phases, and finally clears the data buffer.
        Both dataloaders are wrapped in ``CycledDataLoader`` and stored on the
        trainer as ``self.prompt_dataloader`` / ``self.pretrain_dataloader``.

        Args:
            prompt_dataloader (DataLoader): the dataloader to use for prompt data
            pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
            num_episodes (int): the number of episodes to train
            num_collect_steps (int): the number of collect steps per episode
            num_update_steps (int): the number of update steps per episode
        """
        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)

        with self._fit_ctx():
            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                with self._episode_ctx(episode):
                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                        self._collect_phase(collect_step)
                    if not self.sample_buffer:
                        # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                        # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                        # I only call strategy.setup_dataloader() to setup dataloader.
                        self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                        self._update_phase(update_step)
                    # NOTE: this is for on-policy algorithms
                    self.data_buffer.clear()