[chat] refactor trainer class (#4080)

* to: add SLTrainer

* refactor: refactor RMTrainer and SFTTrainer

* fix: fix init file

* feat: remove on_learn_epoch fn as not used

* fix: align with modified gemini arguments

* to: add OnPolicyTrainer

* revert: add _on_learn_epoch fn

* refactor: refactor PPOTrainer

* style: rename PPOTrainer argument

* fix: align with modified PPO arguments

* test: align with modified train_prompts arguments

* chore: modify train_prompts

* docs: align with modified arguments

* fix: remove unnecessary output

* fix: move dataloader to fit fn of SLTrainer

* fix: move dataloader to fit fn of OnPolicyTrainer

* fix: modify usage of prompt and pretrain dataloader
Wenhao Chen 2023-06-29 10:48:09 +08:00 committed by GitHub
parent 711e2b4c00
commit b03d64d010
16 changed files with 461 additions and 361 deletions
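For readers skimming the diff, a condensed sketch of the `PPOTrainer.fit` interface change described by the commits above. The trainer object and dataloaders are placeholders; only the keyword names come from this PR.

```python
def fit_old(trainer, prompt_dataloader, pretrain_dataloader):
    # before this PR: max_epochs lived on the trainer, schedule given as timesteps
    trainer.fit(prompt_dataloader, pretrain_dataloader,
                num_episodes=10, max_timesteps=10, update_timesteps=10)


def fit_new(trainer, prompt_dataloader, pretrain_dataloader):
    # after this PR: dataloaders are passed by keyword and the schedule is
    # expressed as collect/update steps per episode
    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=10,
                num_collect_steps=10,
                num_update_steps=5)
```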

View File

@ -83,7 +83,7 @@ More details can be found in the latest news.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ColossalChat%20Speed.jpg" width=450/>
</p>
> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --max_timesteps 1 --update_timesteps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32
> DeepSpeedChat performance is taken from its blog post of April 12, 2023; ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --num_collect_steps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32
## Install

View File

@ -137,6 +137,12 @@ def main(args):
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
dataloader = DataLoader(random_prompts,
batch_size=args.experience_batch_size,
shuffle=True,
collate_fn=preprocess_batch)
trainer = PPOTrainer(strategy,
actor,
critic,
@ -145,7 +151,6 @@ def main(args):
actor_optim,
critic_optim,
ptx_coef=0,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
offload_inference_models=args.offload_inference_models,
max_length=512,
@ -157,17 +162,11 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator])
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
dataloader = DataLoader(random_prompts,
batch_size=args.experience_batch_size,
shuffle=True,
collate_fn=preprocess_batch)
trainer.fit(dataloader,
None,
trainer.fit(prompt_dataloader=dataloader,
pretrain_dataloader=None,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps)
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
@ -183,9 +182,8 @@ if __name__ == '__main__':
],
default='ddp')
parser.add_argument('--num_episodes', type=int, default=3)
parser.add_argument('--max_timesteps', type=int, default=8)
parser.add_argument('--update_timesteps', type=int, default=8)
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--num_collect_steps', type=int, default=8)
parser.add_argument('--num_update_steps', type=int, default=1)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0)
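For orientation, a back-of-the-envelope view of how the new flags translate into work per episode, using the argparse defaults above (the exact minibatching is handled by the strategy's buffer dataloader, so treat these as approximate):

```python
# argparse defaults shown in the benchmark above
num_collect_steps = 8       # experience-collection steps per episode
experience_batch_size = 8   # prompts generated per collection step
num_update_steps = 1        # passes over the collected buffer per episode
train_batch_size = 8        # samples per optimizer minibatch

samples_per_episode = num_collect_steps * experience_batch_size         # 64
minibatches_per_pass = samples_per_episode // train_batch_size          # 8
optimizer_steps_per_episode = num_update_steps * minibatches_per_pass   # 8
print(samples_per_episode, minibatches_per_pass, optimizer_steps_per_episode)
```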

View File

@ -1,6 +1,10 @@
from .base import Trainer
from .base import OnPolicyTrainer, SLTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
__all__ = [
'SLTrainer', 'OnPolicyTrainer',
'RewardModelTrainer', 'SFTTrainer',
'PPOTrainer'
]

View File

@ -1,54 +1,108 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from contextlib import contextmanager
from typing import List
import torch
import torch.nn as nn
import tqdm
from coati.experience_maker import Experience
from coati.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .callbacks import Callback
from .strategies import Strategy
from .utils import CycledDataLoader, is_rank_0
class Trainer(ABC):
class SLTrainer(ABC):
"""
Base class for rlhf trainers.
Base class for supervised learning trainers.
Args:
strategy (Strategy): the strategy to use for training
max_epochs (int, defaults to 1): the number of epochs of training process
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
model (nn.Module): the model to train
optimizer (Optimizer): the optimizer to use for training
"""
def __init__(self,
strategy: Strategy,
max_epochs: int = 1,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
max_epochs: int,
model: nn.Module,
optimizer: Optimizer,
) -> None:
super().__init__()
self.strategy = strategy
self.max_epochs = max_epochs
self.generate_kwargs = generate_kwargs
self.model = model
self.optimizer = optimizer
@abstractmethod
def _train(self, epoch):
raise NotImplementedError()
@abstractmethod
def _eval(self, epoch):
raise NotImplementedError()
def _before_fit(self):
self.no_epoch_bar = False
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
for epoch in tqdm.trange(self.max_epochs,
desc="Epochs",
disable=not is_rank_0() or self.no_epoch_bar
):
self._train(epoch)
self._eval(epoch)
class OnPolicyTrainer(ABC):
"""
Base class for on-policy RL trainers, e.g. PPO.
Args:
strategy (Strategy): the strategy to use for training
buffer (NaiveReplayBuffer): the buffer to collect experiences
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
def __init__(self,
strategy: Strategy,
buffer: NaiveReplayBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = []
) -> None:
super().__init__()
self.strategy = strategy
self.buffer = buffer
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
# TODO(ver217): maybe simplify these code using context
def _on_fit_start(self) -> None:
@contextmanager
def _fit_ctx(self) -> None:
for callback in self.callbacks:
callback.on_fit_start()
try:
yield
finally:
for callback in self.callbacks:
callback.on_fit_end()
def _on_fit_end(self) -> None:
for callback in self.callbacks:
callback.on_fit_end()
def _on_episode_start(self, episode: int) -> None:
@contextmanager
def _episode_ctx(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_start(episode)
def _on_episode_end(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_end(episode)
try:
yield
finally:
for callback in self.callbacks:
callback.on_episode_end(episode)
def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
@ -73,3 +127,71 @@ class Trainer(ABC):
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)
@abstractmethod
def _make_experience(self, collect_step: int):
"""
Implement this method to make experience.
"""
raise NotImplementedError()
@abstractmethod
def _learn(self, update_step: int):
"""
Implement this method to learn from experience, either by
sampling from the buffer or by turning the buffer into a dataloader.
"""
raise NotImplementedError()
def _collect_phase(self, collect_step: int):
self._on_make_experience_start()
experience = self._make_experience(collect_step)
self._on_make_experience_end(experience)
self.buffer.append(experience)
def _update_phase(self, update_step: int):
self._on_learn_epoch_start(update_step)
self._learn(update_step)
self._on_learn_epoch_end(update_step)
def fit(self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
num_episodes: int,
num_collect_steps: int,
num_update_steps: int,
):
"""
The main training loop of on-policy RL trainers.
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
num_episodes (int): the number of episodes to train
num_collect_steps (int): the number of collect steps per episode
num_update_steps (int): the number of update steps per episode
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
with self._fit_ctx():
for episode in tqdm.trange(num_episodes,
desc="Episodes",
disable=not is_rank_0()):
with self._episode_ctx(episode):
for collect_step in tqdm.trange(num_collect_steps,
desc="Collect steps",
disable=not is_rank_0()):
self._collect_phase(collect_step)
if not self.sample_buffer:
# HACK(cwher): according to the design of the boost API, the dataloader should also be boosted,
# but that pattern is impractical in RL training, so the dataloader is left unboosted
# and only strategy.setup_dataloader() is called to set it up.
self.dataloader = self.strategy.setup_dataloader(self.buffer,
self.dataloader_pin_memory)
for update_step in tqdm.trange(num_update_steps,
desc="Update steps",
disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.buffer.clear()
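To make the new base-class split concrete, here is a dependency-free sketch of the template-method pattern `SLTrainer` now enforces; `OnPolicyTrainer` follows the same idea with `_make_experience`/`_learn` as the hooks and `fit()` owning the episode/collect/update loop. The classes below are illustrative only, not coati APIs:

```python
from abc import ABC, abstractmethod


class MiniSLTrainer(ABC):
    """Toy mirror of SLTrainer: fit() drives the loop, subclasses fill in the steps."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    @abstractmethod
    def _train(self, epoch: int): ...

    @abstractmethod
    def _eval(self, epoch: int): ...

    def _before_fit(self, *args, **kwargs):
        pass  # subclasses stash dataloaders and set flags here

    def fit(self, *args, **kwargs):
        self._before_fit(*args, **kwargs)
        for epoch in range(self.max_epochs):
            self._train(epoch)
            self._eval(epoch)


class ToyTrainer(MiniSLTrainer):
    def _before_fit(self, data):
        self.data = data

    def _train(self, epoch):
        print(f"epoch {epoch}: training on {len(self.data)} samples")

    def _eval(self, epoch):
        print(f"epoch {epoch}: evaluating")


ToyTrainer(max_epochs=2).fit([1, 2, 3])
```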

View File

@ -1,6 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Dict, List
import torch
import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model
@ -9,19 +8,32 @@ from coati.models.utils import calc_action_log_probs
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from colossalai.utils import get_current_device
from .base import Trainer
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device
class PPOTrainer(Trainer):
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapper_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
return new_kwargs
class PPOTrainer(OnPolicyTrainer):
"""
Trainer for PPO algorithm.
@ -35,14 +47,13 @@ class PPOTrainer(Trainer):
critic_optim (Optimizer): the optimizer to use for critic model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
buffer_limit (int, defaults to 0): the maximum size of the buffer
buffer_cpu_offload (bool, defaults to True): whether to offload the buffer to CPU
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
max_epochs (int, defaults to 1): the number of epochs of training process
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
callbacks (List[Callback], defaults to []): the callbacks to call during training process
@ -65,25 +76,26 @@ class PPOTrainer(Trainer):
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
max_epochs: int = 1,
sample_replay_buffer: bool = False,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
**generate_kwargs
) -> None:
if isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
"GeminiPlugin is not compatible with manual model.to('cpu')"
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)
buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(
strategy, buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
self.experience_maker = experience_maker
self.replay_buffer = replay_buffer
self.sample_replay_buffer = sample_replay_buffer
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
self.offload_inference_models = offload_inference_models
self.actor = actor
@ -99,76 +111,20 @@ class PPOTrainer(Trainer):
self.device = get_current_device()
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
def _make_experience(self, collect_step: int) -> Experience:
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
if isinstance(prompts, Tensor):
return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
elif isinstance(prompts, dict):
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
raise ValueError(f'Unsupported input type "{type(prompts)}"')
def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer:
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for _ in pbar:
experience = self.replay_buffer.sample()
experience.to_device(self.device)
metrics = self.training_step(experience)
pbar.set_postfix(metrics)
else:
for epoch in range(self.max_epochs):
self._on_learn_epoch_start(epoch)
if isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch)
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self.training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._on_learn_epoch_end(epoch)
def fit(self,
prompt_dataloader,
pretrain_dataloader,
num_episodes: int = 50000,
max_timesteps: int = 500,
update_timesteps: int = 5000) -> None:
time = 0
self.pretrain_dataloader = pretrain_dataloader
self.prompt_dataloader = prompt_dataloader
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
for timestep in tqdm(range(max_timesteps),
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
time += 1
prompts = next(iter(self.prompt_dataloader))
self._on_make_experience_start()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
experience = self._make_experience(prompts)
self._on_make_experience_end(experience)
self.replay_buffer.append(experience)
if time % update_timesteps == 0:
if self.offload_inference_models:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
self._learn()
self.replay_buffer.clear()
self._on_episode_end(episode)
self._on_fit_end()
def training_step(self, experience: Experience) -> Dict[str, float]:
def _training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
self.critic.train()
# policy loss
@ -182,7 +138,7 @@ class PPOTrainer(Trainer):
# ptx loss
if self.ptx_coef != 0:
batch = next(iter(self.pretrain_dataloader))
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
@ -208,16 +164,29 @@ class PPOTrainer(Trainer):
return {'reward': experience.reward.mean().item()}
def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapper_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
return new_kwargs
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
experience = self.buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
self._on_learn_batch_end(metrics, experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(
self.dataloader,
desc=f'Train epoch [{update_step + 1}]',
disable=not is_rank_0()
)
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
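The relocated `_set_default_generate_kwargs` helper simply fills in generation hooks when the underlying HuggingFace model provides them. A self-contained illustration of the same pattern follows; the hook names match the diff, while the stand-in model class is invented for the example:

```python
class FakeHFModel:
    """Stand-in exposing the two hooks a HuggingFace model may provide."""

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids, **kwargs}

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs):
        return model_kwargs


def fill_generate_defaults(model, generate_kwargs: dict) -> dict:
    # copy the user kwargs, then default any missing hook from the model
    new_kwargs = {**generate_kwargs}
    if "prepare_inputs_fn" not in generate_kwargs and hasattr(model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = model.prepare_inputs_for_generation
    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = model._update_model_kwargs_for_generation
    return new_kwargs


kwargs = fill_generate_defaults(FakeHFModel(), {"max_length": 512})
print(sorted(kwargs))  # ['max_length', 'prepare_inputs_fn', 'update_model_kwargs_fn']
```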

View File

@ -1,20 +1,19 @@
from datetime import datetime
from typing import Callable, List
from typing import Callable
import pandas as pd
import torch
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from .base import Trainer
from .callbacks import Callback
from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0
class RewardModelTrainer(Trainer):
class RewardModelTrainer(SLTrainer):
"""
Trainer to use while training a reward model.
@ -24,12 +23,7 @@ class RewardModelTrainer(Trainer):
optim (Optimizer): the optimizer to use for training
lr_scheduler (_LRScheduler): the lr scheduler to use for training
loss_fn (callable): the loss function to use for training
train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
def __init__(
@ -39,87 +33,79 @@ class RewardModelTrainer(Trainer):
optim: Optimizer,
lr_scheduler: _LRScheduler,
loss_fn: Callable,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader,
max_epochs: int = 1,
callbacks: List[Callback] = [],
) -> None:
super().__init__(strategy, max_epochs, callbacks=callbacks)
super().__init__(strategy, max_epochs, model, optim)
self.loss_fn = loss_fn
self.scheduler = lr_scheduler
def _eval(self, epoch):
if self.eval_dataloader is not None:
self.model.eval()
dist, on, cnt = 0, 0, 0
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)):
cnt += 1
if chosen_reward[i] > reject_reward[i]:
on += 1
dist += (chosen_reward - reject_reward).mean().item()
self.dist = dist / len(self.eval_dataloader)
self.acc = on / cnt
if is_rank_0():
log = pd.DataFrame(
[[(epoch + 1) * len(self.train_dataloader),
self.loss.item(), self.dist, self.acc]],
columns=['step', 'loss', 'dist', 'acc']
)
log.to_csv('log.csv', mode='a', header=False, index=False)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0()
)
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
self.loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(self.loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
cnt += 1
if cnt % 100 == 0:
self.scheduler.step()
step_bar.update()
step_bar.close()
def _before_fit(self,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
super()._before_fit()
self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader
self.model = model
self.loss_fn = loss_fn
self.optimizer = optim
self.scheduler = lr_scheduler
def eval_acc(self, dataloader):
dist = 0
on = 0
cnt = 0
self.model.eval()
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)):
cnt += 1
if chosen_reward[i] > reject_reward[i]:
on += 1
dist += (chosen_reward - reject_reward).mean().item()
dist_mean = dist / len(dataloader)
acc = on / cnt
self.model.train()
return dist_mean, acc
def fit(self):
time = datetime.now()
epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.max_epochs):
step_bar = tqdm(range(self.train_dataloader.__len__()),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0())
# train
self.model.train()
cnt = 0
acc = 0
dist = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
cnt += 1
if cnt == 100:
self.scheduler.step()
dist, acc = self.eval_acc(self.valid_dataloader)
cnt = 0
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
step_bar.update()
step_bar.set_postfix({'dist': dist, 'acc': acc})
# eval
dist, acc = self.eval_acc(self.eval_dataloader)
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log.csv', mode='a', header=False, index=False)
epoch_bar.update()
step_bar.set_postfix({'dist': dist, 'acc': acc})
step_bar.close()
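The reward-model evaluation in `_eval` boils down to a pairwise accuracy plus a mean reward margin. The same computation on plain tensors, with made-up rewards for illustration:

```python
import torch

chosen_reward = torch.tensor([1.2, 0.3, 2.0, -0.1])
reject_reward = torch.tensor([0.5, 0.9, 1.0, -0.4])

# fraction of pairs where the chosen response outscores the rejected one
acc = (chosen_reward > reject_reward).float().mean().item()
# average margin between chosen and rejected rewards
dist = (chosen_reward - reject_reward).mean().item()
print(f"acc={acc:.2f}, dist={dist:.2f}")  # acc=0.75, dist=0.35
```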

View File

@ -1,21 +1,22 @@
import time
from typing import List
from typing import Optional
import torch
import torch.distributed as dist
import tqdm
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from .base import Trainer
from .callbacks import Callback
from colossalai.logging import DistributedLogger
from .base import SLTrainer
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device
class SFTTrainer(Trainer):
class SFTTrainer(SLTrainer):
"""
Trainer for supervised fine-tuning.
@ -23,12 +24,9 @@ class SFTTrainer(Trainer):
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim (Optimizer): the optimizer to use for training
train_dataloader: the dataloader to use for training
eval_dataloader: the dataloader to use for evaluation
batch_size (int, defaults to 1): the batch size while training
lr_scheduler (_LRScheduler): the lr scheduler to use for training
max_epochs (int, defaults to 2): the number of epochs to train
callbacks (List[Callback], defaults to []): the callbacks to call during training process
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients
"""
def __init__(
@ -37,95 +35,92 @@ class SFTTrainer(Trainer):
strategy: Strategy,
optim: Optimizer,
lr_scheduler: _LRScheduler,
train_dataloader: DataLoader,
eval_dataloader: DataLoader = None,
max_epochs: int = 2,
accumulation_steps: int = 8,
callbacks: List[Callback] = [],
) -> None:
if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not isinstance(strategy.plugin, GeminiPlugin), \
"Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, callbacks=callbacks)
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.model = model
self.optimizer = optim
super().__init__(strategy, max_epochs, model, optim)
self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler
def fit(self, logger, use_wandb: bool = False):
def _train(self, epoch: int):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
self.strategy.backward(loss, self.model, self.optimizer)
self.total_loss += loss.item()
# gradient accumulation
if (batch_id + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and self.use_wandb:
wandb.log({
"loss": self.total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id
})
self.total_loss = 0
self.step_bar.update()
def _eval(self, epoch: int):
if self.eval_dataloader is not None:
self.model.eval()
with torch.no_grad():
loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
loss = outputs.loss
loss_sum += loss.item()
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
def _before_fit(self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
use_wandb: bool = False):
"""
Args:
train_dataloader: the dataloader to use for training
eval_dataloader: the dataloader to use for evaluation
"""
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.logger = logger
self.use_wandb = use_wandb
if use_wandb:
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
wandb.watch(self.model)
total_loss = 0
# epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
desc=f'steps',
disable=not is_rank_0())
for epoch in range(self.max_epochs):
# process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
# train
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss
if loss >= 2.5 and is_rank_0():
logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
loss = loss / self.accumulation_steps
self.strategy.backward(loss, self.model, self.optimizer)
total_loss += loss.item()
# gradient accumulation
if (batch_id + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and use_wandb:
wandb.log({
"loss": total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id
})
total_loss = 0
step_bar.update()
# if batch_id % log_interval == 0:
# logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
# wandb.log({"loss": loss.item()})
# process_bar.update()
# eval
if self.eval_dataloader is not None:
self.model.eval()
with torch.no_grad():
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
loss = outputs.loss
loss_sum += loss.item()
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
# epoch_bar.update()
self.total_loss = 0
self.no_epoch_bar = True
self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
desc='steps',
disable=not is_rank_0()
)
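The SFT `_train` loop uses the standard gradient-accumulation recipe: divide each loss by `accumulation_steps` and step the optimizer only every `accumulation_steps` micro-batches. A minimal, strategy-free sketch with a toy model (not the coati training path):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4
batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]

for batch_id, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y)
    # scale so the accumulated gradient matches one large batch
    (loss / accumulation_steps).backward()
    if (batch_id + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```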

View File

@ -1,4 +1,3 @@
import functools
import warnings
from typing import Optional
@ -103,7 +102,7 @@ class ColossalAIStrategy(DDPStrategy):
# NOTE: dist should be initialized before calling get_current_device()
if stage == 3:
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision=precision,
@ -113,20 +112,20 @@ class ColossalAIStrategy(DDPStrategy):
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
# optim_config
**optim_kwargs)
else:
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
# optim_config
**optim_kwargs)
super().__init__(seed, plugin_initializer)

View File

@ -3,6 +3,33 @@ from typing import Any
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
class CycledDataLoader:
"""
Why do we need this class?
In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" was used to sample a batch of prompt/pretrain data.
However, this can be inefficient because it re-initializes the dataloader (and its workers) on every call.
NOTE: next(iter(dataloader)) is also not equivalent to `for batch in dataloader: break`; the two behave slightly differently.
"""
def __init__(self,
dataloader: DataLoader,
) -> None:
self.dataloader = dataloader
self.count = 0
self.dataloader_iter = iter(dataloader)
def next(self):
self.count += 1
try:
return next(self.dataloader_iter)
except StopIteration:
self.count = 0
self.dataloader_iter = iter(self.dataloader)
return next(self.dataloader_iter)
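A short usage sketch of the new helper (assuming the coati package is importable; the import path below is inferred from the module layout in this diff):

```python
from torch.utils.data import DataLoader

from coati.trainer.utils import CycledDataLoader  # path inferred from this diff

loader = DataLoader(list(range(6)), batch_size=4)
prompts = CycledDataLoader(loader)

# next() keeps one persistent iterator and wraps around on StopIteration,
# unlike next(iter(loader)), which rebuilds the iterator (and any workers)
# on every call and always returns the first batch.
for _ in range(4):
    print(prompts.next())  # tensor([0,1,2,3]), tensor([4,5]), tensor([0,1,2,3]), tensor([4,5])
```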
def is_rank_0() -> bool:

View File

@ -171,9 +171,8 @@ Pretrain dataset: the pretrain dataset including the instruction and correspondi
- --pretrain_dataset: path of the ptx dataset, type=str, default=None
- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
- --num_episodes: num of episodes for training, type=int, default=10
- --max_epochs: max epochs for training in one episode, type=int, default=5
- --max_timesteps: max episodes in one batch, type=int, default=10
- --update_timesteps: timesteps to update, type=int, default=10
- --num_update_steps: number of steps to update policy per episode, type=int
- --num_collect_steps: number of steps to collect experience per episode, type=int
- --train_batch_size: batch size while training, type=int, default=8
- --ptx_batch_size: batch size to compute ptx loss, type=int, default=1
- --experience_batch_size: batch size to make experience, type=int, default=8

View File

@ -171,7 +171,6 @@ def main(args):
critic_optim,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size,
tokenizer=tokenize_fn,
@ -186,8 +185,8 @@ def main(args):
trainer.fit(prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps)
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
@ -215,9 +214,8 @@ if __name__ == '__main__':
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--num_collect_steps', type=int, default=10)
parser.add_argument('--num_update_steps', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=2)
parser.add_argument('--ptx_batch_size', type=int, default=1)
parser.add_argument('--experience_batch_size', type=int, default=8)

View File

@ -63,8 +63,8 @@ for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy $strategy --model $model \
--num_episodes 1 --max_timesteps 2 \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2
--num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
--train_batch_size 2
done
done
@ -149,8 +149,8 @@ rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
--strategy colossalai_zero2 --num_episodes 1 \
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
--pretrain 'facebook/opt-350m' --model opt \
--rm_pretrain 'facebook/opt-350m' \
--rm_path ${BASE}/rm_ckpt_opt.pt \
@ -159,8 +159,8 @@ rm -rf ${BASE}/rm_ckpt_opt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
--strategy colossalai_zero2 --num_episodes 1 \
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
--pretrain 'gpt2' --model gpt2 \
--rm_pretrain 'gpt2' \
--rm_path ${BASE}/rm_ckpt_gpt.pt \
@ -168,8 +168,8 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
--strategy colossalai_gemini --num_episodes 1 \
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
--pretrain 'gpt2' --model gpt2 \
--rm_pretrain 'gpt2' \
--rm_path ${BASE}/rm_ckpt_gpt.pt \

View File

@ -177,7 +177,6 @@ def main(args):
critic_optim,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
max_length=args.max_seq_len,
use_cache=True,
@ -192,8 +191,8 @@ def main(args):
trainer.fit(prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps)
# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
@ -220,9 +219,8 @@ if __name__ == '__main__':
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--num_collect_steps', type=int, default=10)
parser.add_argument('--num_update_steps', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--ptx_batch_size', type=int, default=1)
parser.add_argument('--experience_batch_size', type=int, default=8)

View File

@ -1,13 +1,13 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_prompts.py \
--pretrain_dataset /path/to/data.json \
--prompt_dataset /path/to/data.json \
--strategy colossalai_zero2 \
--num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
--train_batch_size 2

View File

@ -178,12 +178,11 @@ def train(args):
optim=optim,
lr_scheduler=lr_scheduler,
loss_fn=loss_fn,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
eval_dataloader=eval_dataloader,
max_epochs=args.max_epochs)
trainer.fit()
trainer.fit(train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
eval_dataloader=eval_dataloader)
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks

View File

@ -170,12 +170,13 @@ def train(args):
strategy=strategy,
optim=optim,
lr_scheduler=lr_scheduler,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps)
trainer.fit(logger=logger, use_wandb=args.use_wandb)
trainer.fit(train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
logger=logger,
use_wandb=args.use_wandb)
# save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)