diff --git a/applications/Chat/benchmarks/benchmark_gpt_dummy.py b/applications/Chat/benchmarks/benchmark_gpt_dummy.py
index c0d8b1c37..e41ef239d 100644
--- a/applications/Chat/benchmarks/benchmark_gpt_dummy.py
+++ b/applications/Chat/benchmarks/benchmark_gpt_dummy.py
@@ -156,8 +156,10 @@ def main(args):
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=[performance_evaluator])
 
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
index 42df2e1f2..c79435ec6 100644
--- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -149,8 +149,10 @@ def main(args):
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=[performance_evaluator])
 
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py
index 610bb5111..d67679949 100644
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -2,15 +2,10 @@ from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from coati.experience_maker import Experience, ExperienceMaker
-from coati.replay_buffer import ReplayBuffer
-from torch import Tensor
-from torch.utils.data import DistributedSampler
-from tqdm import tqdm
+from coati.experience_maker import Experience
 
 from .callbacks import Callback
 from .strategies import Strategy
-from .utils import is_rank_0
 
 
 class Trainer(ABC):
@@ -19,113 +14,28 @@ class Trainer(ABC):
 
     Args:
         strategy (Strategy):the strategy to use for training
-        experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer
-        replay_buffer (ReplayBuffer): the replay buffer to use for training
-        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
         tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
-        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
-        data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
     """

     def __init__(self,
                  strategy: Strategy,
-                 experience_maker: ExperienceMaker,
-                 replay_buffer: ReplayBuffer,
-                 experience_batch_size: int = 8,
                  max_epochs: int = 1,
                  tokenizer: Optional[Callable[[Any], dict]] = None,
-                 sample_replay_buffer: bool = False,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
         super().__init__()
         self.strategy = strategy
-        self.experience_maker = experience_maker
-        self.replay_buffer = replay_buffer
-        self.experience_batch_size = experience_batch_size
         self.max_epochs = max_epochs
         self.tokenizer = tokenizer
         self.generate_kwargs = generate_kwargs
-        self.sample_replay_buffer = sample_replay_buffer
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
 
-    @abstractmethod
-    def training_step(self, experience: Experience) -> Dict[str, Any]:
-        pass
-
-    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
-        if isinstance(inputs, Tensor):
-            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
-        elif isinstance(inputs, dict):
-            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
-        else:
-            raise ValueError(f'Unsupported input type "{type(inputs)}"')
-
-    def _sample_prompts(self, prompts) -> list:
-        indices = list(range(len(prompts)))
-        sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
-        return [prompts[i] for i in sampled_indices]
-
-    def _learn(self):
-        # replay buffer may be empty at first, we should rebuild at each training
-        if not self.sample_replay_buffer:
-            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
-        device = torch.cuda.current_device()
-        if self.sample_replay_buffer:
-            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-            for _ in pbar:
-                experience = self.replay_buffer.sample()
-                metrics = self.training_step(experience)
-                pbar.set_postfix(metrics)
-        else:
-            for epoch in range(self.max_epochs):
-                self._on_learn_epoch_start(epoch)
-                if isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(epoch)
-                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
-                for experience in pbar:
-                    self._on_learn_batch_start()
-                    experience.to_device(device)
-                    metrics = self.training_step(experience)
-                    self._on_learn_batch_end(metrics, experience)
-                    pbar.set_postfix(metrics)
-                self._on_learn_epoch_end(epoch)
-
-    def fit(self,
-            prompt_dataloader,
-            pretrain_dataloader,
-            num_episodes: int = 50000,
-            max_timesteps: int = 500,
-            update_timesteps: int = 5000) -> None:
-        time = 0
-        self.pretrain_dataloader = pretrain_dataloader
-        self.prompt_dataloader = prompt_dataloader
-        self._on_fit_start()
-        for episode in range(num_episodes):
-            self._on_episode_start(episode)
-            for timestep in tqdm(range(max_timesteps),
-                                 desc=f'Episode [{episode+1}/{num_episodes}]',
-                                 disable=not is_rank_0()):
-                time += 1
-                prompts = next(iter(self.prompt_dataloader))
-                self._on_make_experience_start()
-                self.experience_maker.initial_model.to(torch.cuda.current_device())
-                self.experience_maker.reward_model.to(torch.cuda.current_device())
-                experience = self._make_experience(prompts)
-                self._on_make_experience_end(experience)
-                self.replay_buffer.append(experience)
-                if time % update_timesteps == 0:
-                    self.experience_maker.initial_model.to('cpu')
-                    self.experience_maker.reward_model.to('cpu')
-                    self._learn()
-                    self.replay_buffer.clear()
-            self._on_episode_end(episode)
-        self._on_fit_end()
-
     # TODO(ver217): maybe simplify these code using context
     def _on_fit_start(self) -> None:
         for callback in self.callbacks:
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index d58e437e6..008a6aea8 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -7,12 +7,16 @@ from coati.models.base import Actor, Critic
 from coati.models.generation_utils import update_model_kwargs_fn
 from coati.models.loss import PolicyLoss, ValueLoss
 from coati.replay_buffer import NaiveReplayBuffer
+from torch import Tensor
 from torch.optim import Optimizer
+from torch.utils.data import DistributedSampler
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from tqdm import tqdm
 
 from .base import Trainer
 from .callbacks import Callback
 from .strategies import Strategy
+from .utils import is_rank_0
 
 
 class PPOTrainer(Trainer):
@@ -33,6 +37,7 @@ class PPOTrainer(Trainer):
         buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
         eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
         vf_coef (float, defaults to 1.0): the coefficient of value loss
+        ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
         value_clip (float, defaults to 0.4): the clip coefficient of value loss
         experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
@@ -69,8 +74,13 @@ class PPOTrainer(Trainer):
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
-        super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
-                         sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
+        super().__init__(strategy, max_epochs, tokenizer, dataloader_pin_memory, callbacks, **generate_kwargs)
+
+        self.experience_maker = experience_maker
+        self.replay_buffer = replay_buffer
+        self.experience_batch_size = experience_batch_size
+        self.sample_replay_buffer = sample_replay_buffer
+
         self.actor = actor
         self.critic = critic
 
@@ -82,6 +92,81 @@ class PPOTrainer(Trainer):
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim
 
+    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
+        if isinstance(inputs, Tensor):
+            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
+        elif isinstance(inputs, dict):
+            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
+        else:
+            raise ValueError(f'Unsupported input type "{type(inputs)}"')
+
+    def _sample_prompts(self, prompts) -> list:
+        indices = list(range(len(prompts)))
+        sampled_indices = self.strategy.experience_sampler.choice(
+            indices, self.experience_batch_size, replace=False)
+        return [prompts[i] for i in sampled_indices]
+
+    def _learn(self):
+        # replay buffer may be empty at first, we should rebuild at each training
+        if not self.sample_replay_buffer:
+            dataloader = self.strategy.setup_dataloader(
+                self.replay_buffer, self.dataloader_pin_memory)
+        device = torch.cuda.current_device()
+        if self.sample_replay_buffer:
+            pbar = tqdm(range(self.max_epochs), desc='Train epoch',
+                        disable=not is_rank_0())
+            for _ in pbar:
+                experience = self.replay_buffer.sample()
+                metrics = self.training_step(experience)
+                pbar.set_postfix(metrics)
+        else:
+            for epoch in range(self.max_epochs):
+                self._on_learn_epoch_start(epoch)
+                if isinstance(dataloader.sampler, DistributedSampler):
+                    dataloader.sampler.set_epoch(epoch)
+                pbar = tqdm(
+                    dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
+                for experience in pbar:
+                    self._on_learn_batch_start()
+                    experience.to_device(device)
+                    metrics = self.training_step(experience)
+                    self._on_learn_batch_end(metrics, experience)
+                    pbar.set_postfix(metrics)
+                self._on_learn_epoch_end(epoch)
+
+    def fit(self,
+            prompt_dataloader,
+            pretrain_dataloader,
+            num_episodes: int = 50000,
+            max_timesteps: int = 500,
+            update_timesteps: int = 5000) -> None:
+        time = 0
+        self.pretrain_dataloader = pretrain_dataloader
+        self.prompt_dataloader = prompt_dataloader
+        self._on_fit_start()
+        for episode in range(num_episodes):
+            self._on_episode_start(episode)
+            for timestep in tqdm(range(max_timesteps),
+                                 desc=f'Episode [{episode+1}/{num_episodes}]',
+                                 disable=not is_rank_0()):
+                time += 1
+                prompts = next(iter(self.prompt_dataloader))
+                self._on_make_experience_start()
+                self.experience_maker.initial_model.to(
+                    torch.cuda.current_device())
+                self.experience_maker.reward_model.to(
+                    torch.cuda.current_device())
+                experience = self._make_experience(prompts)
+                self._on_make_experience_end(experience)
+                self.replay_buffer.append(experience)
+                if time % update_timesteps == 0:
+                    self.experience_maker.initial_model.to('cpu')
+                    self.experience_maker.reward_model.to('cpu')
+                    self._learn()
+                    self.replay_buffer.clear()
+            self._on_episode_end(episode)
+        self._on_fit_end()
+
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
         self.critic.train()
diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py
index 0cf09b041..ed6720abc 100644
--- a/applications/Chat/coati/trainer/rm.py
+++ b/applications/Chat/coati/trainer/rm.py
@@ -1,6 +1,5 @@
-from abc import ABC
 from datetime import datetime
-from typing import Optional
+from typing import Optional, List
 
 import pandas as pd
 import torch
@@ -10,11 +9,13 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+from .callbacks import Callback
+from .base import Trainer
 from .strategies import Strategy
 from .utils import is_rank_0
 
 
-class RewardModelTrainer(ABC):
+class RewardModelTrainer(Trainer):
     """
     Trainer to use while training reward model.
 
@@ -23,11 +24,12 @@ class RewardModelTrainer(ABC):
         strategy (Strategy): the strategy to use for training
         optim(Optimizer): the optimizer to use for training
         loss_fn (callable): the loss function to use for training
-        train_dataset (Dataset): the dataset to use for training
-        valid_dataset (Dataset): the dataset to use for validation
-        eval_dataset (Dataset): the dataset to use for evaluation
+        train_dataloader (DataLoader): the dataloader to use for training
+        valid_dataloader (DataLoader): the dataloader to use for validation
+        eval_dataloader (DataLoader): the dataloader to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
+        callbacks (List[Callback], defaults to []): the callbacks to call during training process
     """
 
     def __init__(
@@ -36,25 +38,19 @@
         strategy: Strategy,
         optim: Optimizer,
         loss_fn,
-        train_dataset: Dataset,
-        valid_dataset: Dataset,
-        eval_dataset: Dataset,
+        train_dataloader: DataLoader,
+        valid_dataloader: DataLoader,
+        eval_dataloader: DataLoader,
         batch_size: int = 1,
         max_epochs: int = 1,
+        callbacks: List[Callback] = [],
    ) -> None:
-        super().__init__()
-        self.strategy = strategy
-        self.epochs = max_epochs
+        super().__init__(strategy, max_epochs, callbacks=callbacks)
         train_sampler = None
 
-        if dist.is_initialized() and dist.get_world_size() > 1:
-            train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
-        self.train_dataloader = DataLoader(train_dataset,
-                                           shuffle=(train_sampler is None),
-                                           sampler=train_sampler,
-                                           batch_size=batch_size)
-        self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
-        self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
+        self.train_dataloader = train_dataloader
+        self.valid_dataloader = valid_dataloader
+        self.eval_dataloader = eval_dataloader
 
         self.model = strategy.setup_model(model)
         self.loss_fn = loss_fn
@@ -86,8 +82,8 @@ class RewardModelTrainer(ABC):
 
     def fit(self):
         time = datetime.now()
-        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
-        for epoch in range(self.epochs):
+        epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
+        for epoch in range(self.max_epochs):
             step_bar = tqdm(range(self.train_dataloader.__len__()),
                             desc='Train step of epoch %d' % epoch,
                             disable=not is_rank_0())
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index f380cbf06..350553108 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -1,7 +1,6 @@
 import math
 import time
-from abc import ABC
-from typing import Optional
+from typing import Optional, List
 
 import loralib as lora
 import torch
@@ -19,11 +18,13 @@ from transformers.trainer import get_scheduler
 
 from colossalai.logging import get_dist_logger
 
+from .callbacks import Callback
+from .base import Trainer
 from .strategies import Strategy
 from .utils import is_rank_0
 
 
-class SFTTrainer(ABC):
+class SFTTrainer(Trainer):
     """
     Trainer to use while training reward model.
 
@@ -35,6 +36,7 @@ class SFTTrainer(ABC):
         eval_dataloader: the dataloader to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
+        callbacks (List[Callback], defaults to []): the callbacks to call during training process
         optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
     """
 
@@ -48,10 +50,9 @@
         batch_size: int = 1,
         max_epochs: int = 2,
         accimulation_steps: int = 8,
+        callbacks: List[Callback] = [],
    ) -> None:
-        super().__init__()
-        self.strategy = strategy
-        self.epochs = max_epochs
+        super().__init__(strategy, max_epochs, callbacks=callbacks)
 
         self.train_dataloader = train_dataloader
         self.eval_dataloader = eval_dataloader
@@ -62,7 +63,7 @@
         self.accimulation_steps = accimulation_steps
 
         num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
-        max_steps = math.ceil(self.epochs * num_update_steps_per_epoch)
+        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
 
         self.scheduler = get_scheduler("cosine",
                                        self.optimizer,
@@ -74,10 +75,10 @@
         wandb.watch(self.model)
         total_loss = 0
         # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
-        step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.epochs),
+        step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.max_epochs),
                         desc=f'steps',
                         disable=not is_rank_0())
-        for epoch in range(self.epochs):
+        for epoch in range(self.max_epochs):
             # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
             # train
 
@@ -148,7 +149,7 @@
 
                loss_mean = loss_sum / num_seen
            if dist.get_rank() == 0:
-                logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+                logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
 
         # epoch_bar.update()
 
diff --git a/applications/Chat/examples/train_dummy.py b/applications/Chat/examples/train_dummy.py
index 4ac7ace44..5f34c80f0 100644
--- a/applications/Chat/examples/train_dummy.py
+++ b/applications/Chat/examples/train_dummy.py
@@ -114,8 +114,10 @@ def main(args):
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=callbacks)
 
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 64), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 64), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
@@ -136,7 +138,7 @@ if __name__ == '__main__':
                        default='naive')
     parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
     parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=50)
     parser.add_argument('--max_timesteps', type=int, default=10)
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index aa1b51dea..6a788a891 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -3,6 +3,7 @@ from random import randint
 
 import loralib as lora
 import torch
+import torch.distributed as dist
 from coati.dataset import HhRlhfDataset, RmStaticDataset
 from coati.models import LogExpLoss, LogSigLoss
 from coati.models.base import RewardModel
@@ -17,6 +18,8 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrat
 from coati.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
 
@@ -120,13 +123,38 @@ def train(args):
     else:
         raise ValueError(f'Unsupported dataset "{args.dataset}"')
 
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                          num_replicas=dist.get_world_size())
+    else:
+        train_sampler = None
+        valid_sampler = None
+        eval_sampler = None
+
+    train_dataloader = DataLoader(train_dataset,
+                                  shuffle=(train_sampler is None),
+                                  sampler=train_sampler,
+                                  batch_size=args.batch_size,
+                                  pin_memory=True)
+
+    valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
+                                  sampler=valid_sampler,
+                                  batch_size=args.batch_size, pin_memory=True)
+
+    eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
+                                 sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
+
     trainer = RewardModelTrainer(model=model,
                                  strategy=strategy,
                                  optim=optim,
                                  loss_fn=loss_fn,
-                                 train_dataset=train_dataset,
-                                 valid_dataset=valid_dataset,
-                                 eval_dataset=eval_dataset,
+                                 train_dataloader=train_dataloader,
+                                 valid_dataloader=valid_dataloader,
+                                 eval_dataloader=eval_dataloader,
                                  batch_size=args.batch_size,
                                  max_epochs=args.max_epochs)