from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from coati.experience_maker import Experience

from .callbacks import Callback
from .strategies import Strategy


class Trainer(ABC):
    """
        Base class for rlhf trainers.

    Stores the training configuration and fans each training-loop event out
    to the registered callbacks, in list order.

    Args:
        strategy (Strategy): the strategy to use for training
        max_epochs (int, defaults to 1): the number of epochs of training process
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], optional): the callbacks to call during training process.
            When omitted or None, a new empty list is created for this instance.
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 max_epochs: int = 1,
                 dataloader_pin_memory: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs) -> None:
        super().__init__()
        self.strategy = strategy
        self.max_epochs = max_epochs
        self.generate_kwargs = generate_kwargs
        self.dataloader_pin_memory = dataloader_pin_memory
        # A `[]` default would be created once at definition time and shared by
        # every trainer that omits `callbacks`; build a fresh list per instance.
        # A caller-supplied list is stored as-is (not copied).
        self.callbacks = [] if callbacks is None else callbacks

    # TODO(ver217): maybe simplify these code using context
    def _on_fit_start(self) -> None:
        """Call `on_fit_start()` on every registered callback."""
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        """Call `on_fit_end()` on every registered callback."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        """Call `on_episode_start(episode)` on every registered callback."""
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        """Call `on_episode_end(episode)` on every registered callback."""
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        """Call `on_make_experience_start()` on every registered callback."""
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        """Call `on_make_experience_end(experience)` on every registered callback."""
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        """Call `on_learn_epoch_start(epoch)` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        """Call `on_learn_epoch_end(epoch)` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        """Call `on_learn_batch_start()` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        """Call `on_learn_batch_end(metrics, experience)` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)