from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from chatgpt.experience_maker import Experience, ExperienceMaker
from chatgpt.replay_buffer import ReplayBuffer
from torch import Tensor
from torch.utils.data import DistributedSampler
from tqdm import tqdm

from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0


class Trainer(ABC):
    """
        Base class for rlhf trainers.

    Args:
        strategy (Strategy): the strategy to use for training
        experience_maker (ExperienceMaker): the experience maker used to produce experience that fills the replay buffer
        replay_buffer (ReplayBuffer): the replay buffer to use for training
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of training process
        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], optional): the callbacks to call during training process; defaults to no callbacks
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 experience_maker: ExperienceMaker,
                 replay_buffer: ReplayBuffer,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs) -> None:
        super().__init__()
        self.strategy = strategy
        self.experience_maker = experience_maker
        self.replay_buffer = replay_buffer
        self.experience_batch_size = experience_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        # Every extra keyword argument is forwarded verbatim to experience_maker.make_experience.
        self.generate_kwargs = generate_kwargs
        self.sample_replay_buffer = sample_replay_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        # Give each trainer its own list: a `[]` default would be created once and
        # shared by every trainer constructed without explicit callbacks.
        self.callbacks = callbacks if callbacks is not None else []

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        """Run one optimisation step on ``experience`` and return metrics shown in the progress bar."""
        pass

    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        """Build an Experience from model inputs.

        A Tensor is passed positionally, a dict is unpacked as keyword arguments;
        ``self.generate_kwargs`` are appended in both cases.

        Raises:
            ValueError: if ``inputs`` is neither a Tensor nor a dict.
        """
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    def _sample_prompts(self, prompts) -> list:
        """Pick ``experience_batch_size`` distinct prompts using the strategy's experience sampler.

        NOTE(review): assumes ``experience_sampler.choice`` follows numpy's
        ``choice(a, size, replace)`` contract, so it would raise when
        ``len(prompts) < experience_batch_size`` -- confirm against Strategy.
        """
        indices = list(range(len(prompts)))
        sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
        return [prompts[i] for i in sampled_indices]

    def _learn(self):
        """Run ``max_epochs`` of training on the current replay buffer contents.

        With ``sample_replay_buffer`` set, each "epoch" is a single batch drawn via
        ``replay_buffer.sample()``; learn epoch/batch callbacks are not fired in that mode.
        Otherwise the whole buffer is iterated through a strategy-built dataloader.
        """
        if self.sample_replay_buffer:
            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
            for _ in pbar:
                experience = self.replay_buffer.sample()
                metrics = self.training_step(experience)
                pbar.set_postfix(metrics)
            return

        # The replay buffer may be empty at first, so the dataloader is rebuilt for every learning phase.
        dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
        device = torch.cuda.current_device()
        for epoch in range(self.max_epochs):
            self._on_learn_epoch_start(epoch)
            # A DistributedSampler needs the epoch number to reshuffle differently each epoch.
            if isinstance(dataloader.sampler, DistributedSampler):
                dataloader.sampler.set_epoch(epoch)
            pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(device)
                metrics = self.training_step(experience)
                self._on_learn_batch_end(metrics, experience)
                pbar.set_postfix(metrics)
            self._on_learn_epoch_end(epoch)

    def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
        """Alternate experience generation and learning.

        Args:
            prompts: prompt collection handed to ``strategy.setup_sampler``.
            num_episodes: number of episodes to run.
            max_timesteps: experience-generation steps per episode.
            update_timesteps: learn (then clear the replay buffer) every this many
                timesteps. The counter spans episodes and is never reset.

        Raises:
            ValueError: if ``update_timesteps`` is not positive.
        """
        # Fail fast instead of hitting a modulo-by-zero after generating the first experience.
        if update_timesteps <= 0:
            raise ValueError(f'update_timesteps must be a positive integer, got {update_timesteps}')
        time = 0
        sampler = self.strategy.setup_sampler(prompts)
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for _ in tqdm(range(max_timesteps),
                          desc=f'Episode [{episode+1}/{num_episodes}]',
                          disable=not is_rank_0()):
                time += 1
                rand_prompts = sampler.sample(self.experience_batch_size)
                inputs = self.tokenizer(rand_prompts) if self.tokenizer is not None else rand_prompts
                self._on_make_experience_start()
                experience = self._make_experience(inputs)
                self._on_make_experience_end(experience)
                self.replay_buffer.append(experience)
                if time % update_timesteps == 0:
                    self._learn()
                    self.replay_buffer.clear()
            self._on_episode_end(episode)
        self._on_fit_end()

    # TODO(ver217): maybe simplify these code using context
    # Each hook below forwards to the same-named method on every registered callback, in order.
    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)