mirror of https://github.com/hpcaitech/ColossalAI
180 lines
6.8 KiB
Python
180 lines
6.8 KiB
Python
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import ray
import torch
from coati.experience_buffer.utils import BufferItem
from coati.experience_maker import Experience
from torch.utils.data import DataLoader
from tqdm import tqdm

from .callbacks import TrainerCallback
from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0
|
|
|
|
|
|
class DetachedTrainer(ABC):
    """
    Base class for detached RLHF trainers.

    'Detached' means the experience maker is detached from the trainer, compared to a
    normal Trainer: experience-maker holders run as separate Ray actors (resolved by
    name in ``update_target_holder_list``) and feed this trainer's local replay buffer
    through the ``buffer_*`` methods.

    Please set the actor name when creating the trainer:

        >>> trainer = DetachedTrainer.options(..., name="xxx", ...).remote()

    so that an ExperienceMakerHolder can reach the trainer's replay buffer by the
    actor's name.

    Args:
        experience_maker_holder_name_list (List[str]): names of the experience-maker
            holder actors, looked up in the Ray namespace given by the ``RAY_NAMESPACE``
            environment variable.
        train_batch_size (int, defaults to 8): forwarded to ``DetachedReplayBuffer`` as
            its first argument (presumably the batch size of each sampled item).
        buffer_limit (int, defaults to 0): forwarded to ``DetachedReplayBuffer`` as ``limit``.
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the
            dataloader that replays cached batches in epochs after the first.
        callbacks (List[TrainerCallback], optional): callbacks invoked during training.
            Defaults to a new empty list for each trainer.
        debug (bool, defaults to False): print trace messages for training steps and
            buffer calls.
    """

    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        dataloader_pin_memory: bool = True,
        callbacks: Optional[List[TrainerCallback]] = None,
        debug: bool = False,
    ) -> None:
        super().__init__()
        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
        self.dataloader_pin_memory = dataloader_pin_memory
        # A literal ``[]`` default would be one list shared by every trainer instance,
        # so build a fresh one per instance. A caller-supplied list is kept as-is.
        self.callbacks = callbacks if callbacks is not None else []
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []
        self._is_target_holder_initialized = False
        self._debug = debug

    def update_target_holder_list(self) -> None:
        """Resolve ``target_holder_name_list`` into Ray actor handles, once.

        Each actor is looked up with ``ray.get_actor`` in the namespace named by the
        ``RAY_NAMESPACE`` environment variable, which must therefore be set whenever
        there is at least one name to resolve (otherwise ``KeyError``).
        """
        # The handle list may legitimately be empty, so a separate flag (not its
        # length) records whether resolution has already happened.
        if self._is_target_holder_initialized:
            return
        # Resolve every handle before mutating state: if one lookup raises, nothing is
        # recorded, so a later retry does not append duplicate handles.
        holders = [
            ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])
            for name in self.target_holder_name_list
        ]
        self.target_holder_list.extend(holders)
        self._is_target_holder_initialized = True

    @abstractmethod
    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
        """Push updated model state to the remote experience makers (subclass hook).

        ``fit`` calls this with the defaults after every episode;
        ``sync_models_to_remote_makers`` calls it with ``fully_update=True``
        (presumably a complete rather than incremental sync -- see the subclass).
        """
        pass

    def sync_models_to_remote_makers(self, **kwargs):
        """Call ``_update_remote_makers`` with ``fully_update=True``, forwarding ``kwargs``."""
        self._update_remote_makers(fully_update=True, **kwargs)

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        """Train on one batch and return its metrics (subclass hook).

        The returned dict is shown in the progress bar and passed to the callbacks'
        ``on_batch_end``.
        """
        pass

    def _learn(self, update_steps: int, train_epochs: int) -> None:
        """Run ``train_epochs`` epochs for one update cycle.

        Epoch 0 (warmup) samples ``update_steps`` fresh batches from the replay buffer,
        trains on them, and caches them (moved back to CPU) in a local list. Epochs
        ``1 .. train_epochs - 1`` replay the cached batches in shuffled order.
        """
        data: List[Experience] = []
        # warmup: ``_learn_epoch`` samples from the buffer because ``data`` is empty
        pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
        self._on_epoch_start(0)
        self._learn_epoch(pbar, data)
        self._on_epoch_end(0)

        if train_epochs <= 1:
            # No replay epochs: don't build a DataLoader that would never be iterated.
            return

        if data:
            # Each cached item is already a full batch: load one at a time and unwrap it.
            replay = DataLoader(
                data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
            )
        else:
            # Nothing was cached (update_steps == 0). A shuffling DataLoader would raise
            # on an empty dataset, so replay nothing instead.
            replay = []
        for epoch in range(1, train_epochs):
            pbar = tqdm(replay, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
            self._on_epoch_start(epoch)
            self._learn_epoch(pbar, data)
            self._on_epoch_end(epoch)

    def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
        """Train once per item yielded by ``pbar``.

        If ``data`` is empty on entry, this is a warmup epoch: each iteration samples a
        fresh batch from the replay buffer (the item yielded by ``pbar`` serves only as
        a step counter) and appends that batch to ``data`` after training. Otherwise
        each item yielded by ``pbar`` is itself the batch to train on.
        """
        # Decide once, up front: ``data`` grows during a warmup epoch.
        is_warmup = not data
        for x in pbar:
            if self._debug:
                print("[trainer] training step")
            experience = self._buffer_sample() if is_warmup else x
            experience.to_device(torch.cuda.current_device())
            self._on_batch_start()
            metrics = self.training_step(experience)
            self._on_batch_end(metrics, experience)

            if self._debug:
                print("[trainer] step over")
            # Move the batch back to CPU after training (presumably so cached batches
            # don't hold GPU memory between steps).
            experience.to_device("cpu")
            if is_warmup:
                data.append(experience)
            pbar.set_postfix(metrics)

    def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
        """Train for ``total_steps // update_steps`` episodes.

        Each episode runs ``_learn(update_steps, train_epochs)`` and then
        ``_update_remote_makers()``, with the corresponding callbacks around each phase.
        A remainder of ``total_steps % update_steps`` steps is not trained.

        Raises:
            ValueError: if ``update_steps`` is not positive.
        """
        if update_steps <= 0:
            raise ValueError(f"update_steps must be a positive integer, got {update_steps}")
        self._on_fit_start()
        for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
            self._on_episode_start(i)
            self._learn(update_steps, train_epochs)
            self._on_update_start()
            self._update_remote_makers()
            self._on_update_end()
            self._on_episode_end(i)
        self._on_fit_end()

    @ray.method(concurrency_group="buffer_length")
    def buffer_get_length(self):
        """Return the replay buffer's current length. Called by ExperienceMakerHolder."""
        if self._debug:
            print("[trainer] telling length")
        return self.detached_replay_buffer.get_length()

    @ray.method(concurrency_group="buffer_append")
    def buffer_append(self, experience: Experience):
        """Append one experience to the replay buffer. Called by ExperienceMakerHolder."""
        if self._debug:
            print("[trainer] receiving exp.")
        self.detached_replay_buffer.append(experience)

    @ray.method(concurrency_group="buffer_append")
    def buffer_extend(self, items: List[BufferItem]):
        """Add several buffer items to the replay buffer. Called by ExperienceMakerHolder."""
        if self._debug:
            print("[trainer] receiving exp.")
        self.detached_replay_buffer.extend(items)

    @ray.method(concurrency_group="buffer_sample")
    def _buffer_sample(self):
        """Sample one item from the replay buffer (used by warmup epochs)."""
        return self.detached_replay_buffer.sample()

    # ---- callback dispatch: each hook forwards to every registered callback, in order ----

    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_start(epoch)

    def _on_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_end(epoch)

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_batch_end(metrics, experience)

    def _on_update_start(self) -> None:
        for callback in self.callbacks:
            callback.on_update_start()

    def _on_update_end(self) -> None:
        for callback in self.callbacks:
            callback.on_update_end()
|