mirror of https://github.com/hpcaitech/ColossalAI
163 lines
6.9 KiB
Python
163 lines
6.9 KiB
Python
|
import random
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||
|
|
||
|
import torch
|
||
|
from chatgpt.experience_maker import Experience, ExperienceMaker
|
||
|
from chatgpt.replay_buffer import ReplayBuffer
|
||
|
from torch import Tensor
|
||
|
from torch.utils.data import DistributedSampler
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from .callbacks import Callback
|
||
|
from .strategies import Strategy
|
||
|
from .utils import is_rank_0
|
||
|
|
||
|
|
||
|
class Trainer(ABC):
    """
        Base class for rlhf trainers.

    Subclasses implement ``training_step``; this class drives the outer loop:
    sample prompts, turn them into experience, append the experience to the
    replay buffer and periodically train on the buffer contents.

    Args:
        strategy (Strategy): the strategy to use for training
        experience_maker (ExperienceMaker): the experience maker used to produce experience that fills the replay buffer
        replay_buffer (ReplayBuffer): the replay buffer to use for training
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of training process
        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], optional): the callbacks to call during training process.
            When omitted, each trainer gets its own new empty list.
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 experience_maker: ExperienceMaker,
                 replay_buffer: ReplayBuffer,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs) -> None:
        super().__init__()
        self.strategy = strategy
        self.experience_maker = experience_maker
        self.replay_buffer = replay_buffer
        self.experience_batch_size = experience_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        self.generate_kwargs = generate_kwargs
        self.sample_replay_buffer = sample_replay_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        # A `[]` default would be one list object shared by every trainer built
        # without callbacks. A list supplied by the caller is stored as-is.
        self.callbacks = callbacks if callbacks is not None else []

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        """Run one optimisation step on ``experience``.

        The returned mapping is shown as the progress-bar postfix and, when
        training from the dataloader, is also passed to ``on_learn_batch_end``.
        """
        pass

    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        """Forward ``inputs`` and the stored generate kwargs to the experience maker.

        A tensor is passed as the single positional argument; a dict is
        expanded into keyword arguments.

        Raises:
            ValueError: if ``inputs`` is neither a ``Tensor`` nor a ``dict``.
        """
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        if isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        raise ValueError(f'Unsupported input type "{type(inputs)}"')

    def _sample_prompts(self, prompts) -> list:
        """Return ``experience_batch_size`` prompts drawn without replacement.

        ``prompts`` only needs ``len()`` and integer indexing.

        Raises:
            ValueError: raised by ``random.sample`` when ``experience_batch_size``
                is larger than ``len(prompts)`` (or negative).
        """
        sampled_indices = random.sample(range(len(prompts)), self.experience_batch_size)
        return [prompts[i] for i in sampled_indices]

    def _learn(self):
        """Train on the current replay buffer contents for ``max_epochs`` epochs.

        With ``sample_replay_buffer`` set, each epoch is a single
        ``training_step`` on one ``replay_buffer.sample()`` and no learn
        callbacks are invoked. Otherwise a dataloader is built over the buffer
        and every batch is moved to the current CUDA device and trained on,
        with the learn epoch/batch callbacks fired around it.
        """
        if self.sample_replay_buffer:
            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
            for _ in pbar:
                experience = self.replay_buffer.sample()
                metrics = self.training_step(experience)
                pbar.set_postfix(metrics)
            return

        # replay buffer may be empty at first, we should rebuild at each training
        dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
        device = torch.cuda.current_device()
        for epoch in range(self.max_epochs):
            self._on_learn_epoch_start(epoch)
            if isinstance(dataloader.sampler, DistributedSampler):
                dataloader.sampler.set_epoch(epoch)
            pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(device)
                metrics = self.training_step(experience)
                self._on_learn_batch_end(metrics, experience)
                pbar.set_postfix(metrics)
            self._on_learn_epoch_end(epoch)

    def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
        """Run the experience-collection / training loop.

        Args:
            prompts: indexable collection of prompts to sample from.
            num_episodes (int): number of episodes to run.
            max_timesteps (int): number of experience-making steps per episode.
            update_timesteps (int): train on, then clear, the replay buffer every
                time the global step count reaches a multiple of this value.
        """
        # Counts steps across all episodes; it is not reset per episode.
        total_steps = 0
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for _ in tqdm(range(max_timesteps),
                          desc=f'Episode [{episode+1}/{num_episodes}]',
                          disable=not is_rank_0()):
                total_steps += 1
                rand_prompts = self._sample_prompts(prompts)
                if self.tokenizer is not None:
                    inputs = self.tokenizer(rand_prompts)
                else:
                    inputs = rand_prompts
                self._on_make_experience_start()
                experience = self._make_experience(inputs)
                self._on_make_experience_end(experience)
                self.replay_buffer.append(experience)
                if total_steps % update_timesteps == 0:
                    self._learn()
                    self.replay_buffer.clear()
            self._on_episode_end(episode)
        self._on_fit_end()

    # TODO(ver217): maybe simplify these code using context
    def _on_fit_start(self) -> None:
        """Call ``on_fit_start`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        """Call ``on_fit_end`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        """Call ``on_episode_start(episode)`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        """Call ``on_episode_end(episode)`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        """Call ``on_make_experience_start`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        """Call ``on_make_experience_end(experience)`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        """Call ``on_learn_epoch_start(epoch)`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        """Call ``on_learn_epoch_end(epoch)`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        """Call ``on_learn_batch_start`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        """Call ``on_learn_batch_end(metrics, experience)`` on every callback, in list order."""
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)