2024-03-29 06:12:29 +00:00
|
|
|
"""
|
|
|
|
Base trainers for online and offline training
|
|
|
|
SLTrainer: supervised learning trainer
|
|
|
|
pretrain, sft, dpo, reward model training
|
|
|
|
OLTrainer: online learning trainer
|
|
|
|
rlhf-ppo
|
|
|
|
"""
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
from abc import ABC, abstractmethod
|
2023-06-29 02:48:09 +00:00
|
|
|
from contextlib import contextmanager
|
2024-03-29 06:12:29 +00:00
|
|
|
from typing import Callable, List
|
2023-03-28 12:25:36 +00:00
|
|
|
|
2023-06-29 02:48:09 +00:00
|
|
|
import torch.nn as nn
|
|
|
|
import tqdm
|
2023-08-02 02:17:36 +00:00
|
|
|
from coati.experience_buffer import NaiveExperienceBuffer
|
2023-04-18 08:44:03 +00:00
|
|
|
from coati.experience_maker import Experience
|
2023-06-29 02:48:09 +00:00
|
|
|
from torch.optim import Optimizer
|
2023-03-28 12:25:36 +00:00
|
|
|
|
[ColossalChat] Add PP support (#6001)
* support pp training
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support pp training
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update
* skip pp eval
* update all reduce
* update sft
* update ignore
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update no cache
* add eval
* remove fi
* remove debug
* remove parentheses to avoid warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Revert "add eval"
This reverts commit 3ab2f6fa329b6d12959fb3c668d278b4b225c5f0.
* add all reduce
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-08-21 02:47:39 +00:00
|
|
|
from colossalai.booster import Booster, Plugin
|
2024-03-29 06:12:29 +00:00
|
|
|
|
2023-09-20 07:53:58 +00:00
|
|
|
from .utils import is_rank_0
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
2023-06-29 02:48:09 +00:00
|
|
|
class SLTrainer(ABC):
|
2023-03-28 12:25:36 +00:00
|
|
|
"""
|
2023-06-29 02:48:09 +00:00
|
|
|
Base class for supervised learning trainers.
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
Args:
|
|
|
|
strategy (Strategy):the strategy to use for training
|
|
|
|
max_epochs (int, defaults to 1): the number of epochs of training process
|
2023-06-29 02:48:09 +00:00
|
|
|
model (nn.Module): the model to train
|
|
|
|
optim (Optimizer): the optimizer to use for training
|
|
|
|
"""
|
|
|
|
|
2023-07-18 02:59:57 +00:00
|
|
|
def __init__(
|
|
|
|
self,
|
2024-03-29 06:12:29 +00:00
|
|
|
booster: Booster,
|
2023-07-18 02:59:57 +00:00
|
|
|
max_epochs: int,
|
|
|
|
model: nn.Module,
|
|
|
|
optimizer: Optimizer,
|
[ColossalChat] Add PP support (#6001)
* support pp training
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support pp training
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update
* skip pp eval
* update all reduce
* update sft
* update ignore
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update no cache
* add eval
* remove fi
* remove debug
* remove parentheses to avoid warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Revert "add eval"
This reverts commit 3ab2f6fa329b6d12959fb3c668d278b4b225c5f0.
* add all reduce
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-08-21 02:47:39 +00:00
|
|
|
plugin: Plugin,
|
2024-03-29 06:12:29 +00:00
|
|
|
start_epoch: int = 0,
|
2023-07-18 02:59:57 +00:00
|
|
|
) -> None:
|
2023-06-29 02:48:09 +00:00
|
|
|
super().__init__()
|
2024-03-29 06:12:29 +00:00
|
|
|
self.booster = booster
|
2023-06-29 02:48:09 +00:00
|
|
|
self.max_epochs = max_epochs
|
|
|
|
self.model = model
|
|
|
|
self.optimizer = optimizer
|
[ColossalChat] Add PP support (#6001)
* support pp training
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support pp training
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update
* skip pp eval
* update all reduce
* update sft
* update ignore
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update no cache
* add eval
* remove fi
* remove debug
* remove parentheses to avoid warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Revert "add eval"
This reverts commit 3ab2f6fa329b6d12959fb3c668d278b4b225c5f0.
* add all reduce
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-08-21 02:47:39 +00:00
|
|
|
self.plugin = plugin
|
2024-03-29 06:12:29 +00:00
|
|
|
self.start_epoch = start_epoch
|
2023-06-29 02:48:09 +00:00
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def _train(self, epoch):
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def _eval(self, epoch):
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
2024-03-29 06:12:29 +00:00
|
|
|
@abstractmethod
|
2023-06-29 02:48:09 +00:00
|
|
|
def _before_fit(self):
|
2023-09-20 07:53:58 +00:00
|
|
|
raise NotImplementedError()
|
2023-06-29 02:48:09 +00:00
|
|
|
|
|
|
|
def fit(self, *args, **kwargs):
|
|
|
|
self._before_fit(*args, **kwargs)
|
2024-03-29 06:12:29 +00:00
|
|
|
for epoch in tqdm.trange(self.start_epoch, self.max_epochs, desc="Epochs", disable=not is_rank_0()):
|
2023-06-29 02:48:09 +00:00
|
|
|
self._train(epoch)
|
|
|
|
self._eval(epoch)
|
|
|
|
|
|
|
|
|
2024-03-29 06:12:29 +00:00
|
|
|
class OLTrainer(ABC):
|
2023-06-29 02:48:09 +00:00
|
|
|
"""
|
2024-03-29 06:12:29 +00:00
|
|
|
Base class for online learning trainers, e.g. PPO.
|
2023-06-29 02:48:09 +00:00
|
|
|
|
|
|
|
Args:
|
|
|
|
strategy (Strategy):the strategy to use for training
|
2023-08-02 02:17:36 +00:00
|
|
|
data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
|
2023-06-29 02:48:09 +00:00
|
|
|
sample_buffer (bool, defaults to False): whether to sample from buffer
|
2023-04-18 08:44:03 +00:00
|
|
|
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
2023-03-28 12:25:36 +00:00
|
|
|
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
|
|
|
"""
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def __init__(
|
|
|
|
self,
|
2024-03-29 06:12:29 +00:00
|
|
|
actor_booster: Booster,
|
|
|
|
critic_booster: Booster,
|
2023-09-19 06:20:26 +00:00
|
|
|
data_buffer: NaiveExperienceBuffer,
|
|
|
|
sample_buffer: bool,
|
|
|
|
dataloader_pin_memory: bool,
|
2024-03-29 06:12:29 +00:00
|
|
|
callbacks: List[Callable] = [],
|
2023-09-19 06:20:26 +00:00
|
|
|
) -> None:
|
2023-03-28 12:25:36 +00:00
|
|
|
super().__init__()
|
2024-03-29 06:12:29 +00:00
|
|
|
self.actor_booster = actor_booster
|
|
|
|
self.critic_booster = critic_booster
|
2023-08-02 02:17:36 +00:00
|
|
|
self.data_buffer = data_buffer
|
2023-06-29 02:48:09 +00:00
|
|
|
self.sample_buffer = sample_buffer
|
2023-03-28 12:25:36 +00:00
|
|
|
self.dataloader_pin_memory = dataloader_pin_memory
|
|
|
|
self.callbacks = callbacks
|
|
|
|
|
2023-06-29 02:48:09 +00:00
|
|
|
@contextmanager
|
|
|
|
def _fit_ctx(self) -> None:
|
2023-03-28 12:25:36 +00:00
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_fit_start()
|
2023-06-29 02:48:09 +00:00
|
|
|
try:
|
|
|
|
yield
|
|
|
|
finally:
|
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_fit_end()
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
def _episode_ctx(self, episode: int) -> None:
|
2023-03-28 12:25:36 +00:00
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_episode_start(episode)
|
2023-06-29 02:48:09 +00:00
|
|
|
try:
|
|
|
|
yield
|
|
|
|
finally:
|
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_episode_end(episode)
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
def _on_make_experience_start(self) -> None:
|
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_make_experience_start()
|
|
|
|
|
|
|
|
def _on_make_experience_end(self, experience: Experience) -> None:
|
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_make_experience_end(experience)
|
|
|
|
|
|
|
|
def _on_learn_epoch_start(self, epoch: int) -> None:
|
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_learn_epoch_start(epoch)
|
|
|
|
|
|
|
|
def _on_learn_epoch_end(self, epoch: int) -> None:
|
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_learn_epoch_end(epoch)
|
|
|
|
|
|
|
|
def _on_learn_batch_start(self) -> None:
|
|
|
|
for callback in self.callbacks:
|
|
|
|
callback.on_learn_batch_start()
|
|
|
|
|
2023-09-20 07:53:58 +00:00
|
|
|
def _on_learn_batch_end(self, experience: Experience) -> None:
|
2023-03-28 12:25:36 +00:00
|
|
|
for callback in self.callbacks:
|
2023-09-20 07:53:58 +00:00
|
|
|
callback.on_learn_batch_end(experience)
|
2023-06-29 02:48:09 +00:00
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def _make_experience(self, collect_step: int):
|
|
|
|
"""
|
|
|
|
Implement this method to make experience.
|
|
|
|
"""
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def _learn(self, update_step: int):
|
|
|
|
"""
|
2023-07-18 02:59:57 +00:00
|
|
|
Implement this method to learn from experience, either
|
2023-06-29 02:48:09 +00:00
|
|
|
sample from buffer or transform buffer into dataloader.
|
|
|
|
"""
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
2024-03-29 06:12:29 +00:00
|
|
|
@abstractmethod
|
|
|
|
def _setup_update_phrase_dataload(self):
|
|
|
|
"""
|
|
|
|
Implement this method to setup dataloader for update phase.
|
|
|
|
"""
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def _save_checkpoint(self, episode: int = 0):
|
|
|
|
"""
|
|
|
|
Implement this method to save checkpoint.
|
|
|
|
"""
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
2023-06-29 02:48:09 +00:00
|
|
|
def _collect_phase(self, collect_step: int):
|
|
|
|
self._on_make_experience_start()
|
|
|
|
experience = self._make_experience(collect_step)
|
|
|
|
self._on_make_experience_end(experience)
|
2023-08-02 02:17:36 +00:00
|
|
|
self.data_buffer.append(experience)
|
2023-06-29 02:48:09 +00:00
|
|
|
|
|
|
|
def _update_phase(self, update_step: int):
|
|
|
|
self._on_learn_epoch_start(update_step)
|
|
|
|
self._learn(update_step)
|
|
|
|
self._on_learn_epoch_end(update_step)
|
|
|
|
|
2023-09-20 07:53:58 +00:00
|
|
|
def _before_fit(self, *args, **kwargs):
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
2023-07-18 02:59:57 +00:00
|
|
|
def fit(
|
|
|
|
self,
|
|
|
|
num_episodes: int,
|
|
|
|
num_collect_steps: int,
|
|
|
|
num_update_steps: int,
|
2023-09-20 07:53:58 +00:00
|
|
|
*args,
|
|
|
|
**kwargs,
|
2023-07-18 02:59:57 +00:00
|
|
|
):
|
2023-06-29 02:48:09 +00:00
|
|
|
"""
|
|
|
|
The main training loop of on-policy rl trainers.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
num_episodes (int): the number of episodes to train
|
|
|
|
num_collect_steps (int): the number of collect steps per episode
|
|
|
|
num_update_steps (int): the number of update steps per episode
|
|
|
|
"""
|
2023-09-20 07:53:58 +00:00
|
|
|
self._before_fit(*args, **kwargs)
|
2023-06-29 02:48:09 +00:00
|
|
|
with self._fit_ctx():
|
2023-07-18 02:59:57 +00:00
|
|
|
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
|
2023-06-29 02:48:09 +00:00
|
|
|
with self._episode_ctx(episode):
|
2023-07-18 02:59:57 +00:00
|
|
|
for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
|
2023-06-29 02:48:09 +00:00
|
|
|
self._collect_phase(collect_step)
|
|
|
|
if not self.sample_buffer:
|
2024-03-29 06:12:29 +00:00
|
|
|
self._setup_update_phrase_dataload()
|
2023-07-18 02:59:57 +00:00
|
|
|
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
|
2023-06-29 02:48:09 +00:00
|
|
|
self._update_phase(update_step)
|
|
|
|
# NOTE: this is for on-policy algorithms
|
2023-08-02 02:17:36 +00:00
|
|
|
self.data_buffer.clear()
|
2024-03-29 06:12:29 +00:00
|
|
|
if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0:
|
|
|
|
self._save_checkpoint(episode + 1)
|