2023-03-28 12:25:36 +00:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
|
|
|
|
|
import torch
|
2023-04-18 08:44:03 +00:00
|
|
|
from coati.experience_maker import Experience
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
from .callbacks import Callback
|
|
|
|
from .strategies import Strategy
|
|
|
|
|
|
|
|
|
|
|
|
class Trainer(ABC):
    """
    Base class for rlhf trainers.

    Stores the shared training configuration and provides the ``_on_*`` hook
    methods, each of which forwards one event to every registered callback
    in list order.

    Args:
        strategy (Strategy): the strategy to use for training
        max_epochs (int, defaults to 1): the number of epochs of training process
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], optional): the callbacks to call during training process.
            When omitted (or ``None``), each trainer gets its own new empty list.
            A list passed by the caller is stored as-is (not copied).
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 max_epochs: int = 1,
                 dataloader_pin_memory: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs) -> None:
        super().__init__()
        self.strategy = strategy
        self.max_epochs = max_epochs
        self.generate_kwargs = generate_kwargs
        self.dataloader_pin_memory = dataloader_pin_memory
        # A literal ``[]`` default would be created once and shared by every
        # trainer constructed without callbacks; build a fresh list instead.
        self.callbacks = [] if callbacks is None else callbacks

    # TODO(ver217): maybe simplify these code using context
    def _on_fit_start(self) -> None:
        """Call ``on_fit_start`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        """Call ``on_fit_end`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        """Call ``on_episode_start(episode)`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        """Call ``on_episode_end(episode)`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        """Call ``on_make_experience_start`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        """Call ``on_make_experience_end(experience)`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        """Call ``on_learn_epoch_start(epoch)`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        """Call ``on_learn_epoch_end(epoch)`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        """Call ``on_learn_batch_start`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        """Call ``on_learn_batch_end(metrics, experience)`` on every registered callback."""
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)
|