reconstruct chat trainer and fix training script (#3588)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
pull/3592/head
Yuanchen 2023-04-18 16:44:03 +08:00 committed by GitHub
parent dac127d0ee
commit 1ec0d386a9
8 changed files with 163 additions and 137 deletions

View File

@@ -156,8 +156,10 @@ def main(args):
                           eos_token_id=tokenizer.eos_token_id,
                           callbacks=[performance_evaluator])

-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)

View File

@@ -149,8 +149,10 @@ def main(args):
                           eos_token_id=tokenizer.eos_token_id,
                           callbacks=[performance_evaluator])

-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)

View File

@@ -2,15 +2,10 @@ from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
-from coati.experience_maker import Experience, ExperienceMaker
-from coati.replay_buffer import ReplayBuffer
-from torch import Tensor
-from torch.utils.data import DistributedSampler
-from tqdm import tqdm
+from coati.experience_maker import Experience

 from .callbacks import Callback
 from .strategies import Strategy
-from .utils import is_rank_0


 class Trainer(ABC):
@@ -19,113 +14,28 @@ class Trainer(ABC):
     Args:
         strategy (Strategy):the strategy to use for training
-        experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer
-        replay_buffer (ReplayBuffer): the replay buffer to use for training
-        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
         tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
-        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
-        data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
     """

     def __init__(self,
                  strategy: Strategy,
-                 experience_maker: ExperienceMaker,
-                 replay_buffer: ReplayBuffer,
-                 experience_batch_size: int = 8,
                  max_epochs: int = 1,
                  tokenizer: Optional[Callable[[Any], dict]] = None,
-                 sample_replay_buffer: bool = False,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
         super().__init__()
         self.strategy = strategy
-        self.experience_maker = experience_maker
-        self.replay_buffer = replay_buffer
-        self.experience_batch_size = experience_batch_size
         self.max_epochs = max_epochs
         self.tokenizer = tokenizer
         self.generate_kwargs = generate_kwargs
-        self.sample_replay_buffer = sample_replay_buffer
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks

-    @abstractmethod
-    def training_step(self, experience: Experience) -> Dict[str, Any]:
-        pass
-
-    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
-        if isinstance(inputs, Tensor):
-            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
-        elif isinstance(inputs, dict):
-            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
-        else:
-            raise ValueError(f'Unsupported input type "{type(inputs)}"')
-
-    def _sample_prompts(self, prompts) -> list:
-        indices = list(range(len(prompts)))
-        sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
-        return [prompts[i] for i in sampled_indices]
-
-    def _learn(self):
-        # replay buffer may be empty at first, we should rebuild at each training
-        if not self.sample_replay_buffer:
-            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
-        device = torch.cuda.current_device()
-        if self.sample_replay_buffer:
-            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-            for _ in pbar:
-                experience = self.replay_buffer.sample()
-                metrics = self.training_step(experience)
-                pbar.set_postfix(metrics)
-        else:
-            for epoch in range(self.max_epochs):
-                self._on_learn_epoch_start(epoch)
-                if isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(epoch)
-                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
-                for experience in pbar:
-                    self._on_learn_batch_start()
-                    experience.to_device(device)
-                    metrics = self.training_step(experience)
-                    self._on_learn_batch_end(metrics, experience)
-                    pbar.set_postfix(metrics)
-                self._on_learn_epoch_end(epoch)
-
-    def fit(self,
-            prompt_dataloader,
-            pretrain_dataloader,
-            num_episodes: int = 50000,
-            max_timesteps: int = 500,
-            update_timesteps: int = 5000) -> None:
-        time = 0
-        self.pretrain_dataloader = pretrain_dataloader
-        self.prompt_dataloader = prompt_dataloader
-        self._on_fit_start()
-        for episode in range(num_episodes):
-            self._on_episode_start(episode)
-            for timestep in tqdm(range(max_timesteps),
-                                 desc=f'Episode [{episode+1}/{num_episodes}]',
-                                 disable=not is_rank_0()):
-                time += 1
-                prompts = next(iter(self.prompt_dataloader))
-                self._on_make_experience_start()
-                self.experience_maker.initial_model.to(torch.cuda.current_device())
-                self.experience_maker.reward_model.to(torch.cuda.current_device())
-                experience = self._make_experience(prompts)
-                self._on_make_experience_end(experience)
-                self.replay_buffer.append(experience)
-                if time % update_timesteps == 0:
-                    self.experience_maker.initial_model.to('cpu')
-                    self.experience_maker.reward_model.to('cpu')
-                    self._learn()
-                    self.replay_buffer.clear()
-            self._on_episode_end(episode)
-        self._on_fit_end()
-
     # TODO(ver217): maybe simplify these code using context
     def _on_fit_start(self) -> None:
         for callback in self.callbacks:
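After this refactor the base class keeps only shared state (strategy, epochs, tokenizer, pin-memory flag, callbacks, generate kwargs) plus the `_on_*` callback hooks; the experience maker, replay buffer and the episode loop move into `PPOTrainer` below, and `RewardModelTrainer`/`SFTTrainer` just call the slimmed-down constructor. A minimal sketch of what a subclass now looks like, assuming the `coati.trainer` package layout shown in this diff (the class name and the `fit` body are illustrative, not part of the commit):

from typing import List

# Assumed module paths, mirroring the relative imports in coati/trainer above.
from coati.trainer.base import Trainer
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy


class MySimpleTrainer(Trainer):
    """Illustrative subclass: the base Trainer no longer prescribes a fit/training_step shape."""

    def __init__(self, strategy: Strategy, max_epochs: int = 1, callbacks: List[Callback] = []) -> None:
        # New base signature: (strategy, max_epochs, tokenizer, dataloader_pin_memory,
        # callbacks, **generate_kwargs); only the pieces needed here are passed.
        super().__init__(strategy, max_epochs, callbacks=callbacks)

    def fit(self, train_dataloader) -> None:
        self._on_fit_start()                 # callback hooks are still provided by the base class
        for epoch in range(self.max_epochs):
            for batch in train_dataloader:
                pass                         # subclass-specific forward/backward/optimizer step
        self._on_fit_end()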

View File

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 import torch.nn as nn
@@ -7,12 +7,16 @@ from coati.models.base import Actor, Critic
 from coati.models.generation_utils import update_model_kwargs_fn
 from coati.models.loss import PolicyLoss, ValueLoss
 from coati.replay_buffer import NaiveReplayBuffer
+from torch import Tensor
 from torch.optim import Optimizer
+from torch.utils.data import DistributedSampler
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from tqdm import tqdm

 from .base import Trainer
 from .callbacks import Callback
 from .strategies import Strategy
+from .utils import is_rank_0


 class PPOTrainer(Trainer):
@@ -33,6 +37,7 @@ class PPOTrainer(Trainer):
         buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
         eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
         vf_coef (float, defaults to 1.0): the coefficient of value loss
+        ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
         value_clip (float, defaults to 0.4): the clip coefficient of value loss
         experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
@@ -69,8 +74,13 @@ class PPOTrainer(Trainer):
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
-        super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
-                         sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
+        super().__init__(strategy, max_epochs, tokenizer, dataloader_pin_memory, callbacks, **generate_kwargs)
+        self.experience_maker = experience_maker
+        self.replay_buffer = replay_buffer
+        self.experience_batch_size = experience_batch_size
+        self.sample_replay_buffer = sample_replay_buffer

         self.actor = actor
         self.critic = critic
@@ -82,6 +92,81 @@ class PPOTrainer(Trainer):
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim

+    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
+        if isinstance(inputs, Tensor):
+            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
+        elif isinstance(inputs, dict):
+            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
+        else:
+            raise ValueError(f'Unsupported input type "{type(inputs)}"')
+
+    def _sample_prompts(self, prompts) -> list:
+        indices = list(range(len(prompts)))
+        sampled_indices = self.strategy.experience_sampler.choice(
+            indices, self.experience_batch_size, replace=False)
+        return [prompts[i] for i in sampled_indices]
+
+    def _learn(self):
+        # replay buffer may be empty at first, we should rebuild at each training
+        if not self.sample_replay_buffer:
+            dataloader = self.strategy.setup_dataloader(
+                self.replay_buffer, self.dataloader_pin_memory)
+        device = torch.cuda.current_device()
+        if self.sample_replay_buffer:
+            pbar = tqdm(range(self.max_epochs), desc='Train epoch',
+                        disable=not is_rank_0())
+            for _ in pbar:
+                experience = self.replay_buffer.sample()
+                metrics = self.training_step(experience)
+                pbar.set_postfix(metrics)
+        else:
+            for epoch in range(self.max_epochs):
+                self._on_learn_epoch_start(epoch)
+                if isinstance(dataloader.sampler, DistributedSampler):
+                    dataloader.sampler.set_epoch(epoch)
+                pbar = tqdm(
+                    dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
+                for experience in pbar:
+                    self._on_learn_batch_start()
+                    experience.to_device(device)
+                    metrics = self.training_step(experience)
+                    self._on_learn_batch_end(metrics, experience)
+                    pbar.set_postfix(metrics)
+                self._on_learn_epoch_end(epoch)
+
+    def fit(self,
+            prompt_dataloader,
+            pretrain_dataloader,
+            num_episodes: int = 50000,
+            max_timesteps: int = 500,
+            update_timesteps: int = 5000) -> None:
+        time = 0
+        self.pretrain_dataloader = pretrain_dataloader
+        self.prompt_dataloader = prompt_dataloader
+        self._on_fit_start()
+        for episode in range(num_episodes):
+            self._on_episode_start(episode)
+            for timestep in tqdm(range(max_timesteps),
+                                 desc=f'Episode [{episode+1}/{num_episodes}]',
+                                 disable=not is_rank_0()):
+                time += 1
+                prompts = next(iter(self.prompt_dataloader))
+                self._on_make_experience_start()
+                self.experience_maker.initial_model.to(
+                    torch.cuda.current_device())
+                self.experience_maker.reward_model.to(
+                    torch.cuda.current_device())
+                experience = self._make_experience(prompts)
+                self._on_make_experience_end(experience)
+                self.replay_buffer.append(experience)
+                if time % update_timesteps == 0:
+                    self.experience_maker.initial_model.to('cpu')
+                    self.experience_maker.reward_model.to('cpu')
+                    self._learn()
+                    self.replay_buffer.clear()
+            self._on_episode_end(episode)
+        self._on_fit_end()
+
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
         self.critic.train()
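`_make_experience`, now a `PPOTrainer` method, dispatches on the type of the prompt batch pulled from `prompt_dataloader`: a raw tensor is passed positionally to the experience maker, a dict is unpacked as keyword arguments, and anything else raises. A small hedged illustration of the two accepted batch shapes (the vocab size and sequence length below are placeholders, not values from this commit):

import torch

vocab_size, seq_len = 50257, 64                      # placeholder sizes for illustration only

# Tensor batch: forwarded as experience_maker.make_experience(batch, **generate_kwargs)
tensor_batch = torch.randint(vocab_size, (1, seq_len))

# Dict batch: unpacked as experience_maker.make_experience(**batch, **generate_kwargs),
# so its keys must match the experience maker's parameter names.
dict_batch = {'input_ids': tensor_batch}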

View File

@@ -1,6 +1,5 @@
-from abc import ABC
 from datetime import datetime
-from typing import Optional
+from typing import Optional, List

 import pandas as pd
 import torch
@@ -10,11 +9,13 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

+from .callbacks import Callback
+from .base import Trainer
 from .strategies import Strategy
 from .utils import is_rank_0


-class RewardModelTrainer(ABC):
+class RewardModelTrainer(Trainer):
     """
         Trainer to use while training reward model.
@@ -23,11 +24,12 @@ class RewardModelTrainer(ABC):
         strategy (Strategy): the strategy to use for training
         optim(Optimizer): the optimizer to use for training
         loss_fn (callable): the loss function to use for training
-        train_dataset (Dataset): the dataset to use for training
-        valid_dataset (Dataset): the dataset to use for validation
-        eval_dataset (Dataset): the dataset to use for evaluation
+        train_dataloader (DataLoader): the dataloader to use for training
+        valid_dataloader (DataLoader): the dataloader to use for validation
+        eval_dataloader (DataLoader): the dataloader to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
+        callbacks (List[Callback], defaults to []): the callbacks to call during training process
     """

     def __init__(
@@ -36,25 +38,19 @@ class RewardModelTrainer(ABC):
         strategy: Strategy,
         optim: Optimizer,
         loss_fn,
-        train_dataset: Dataset,
-        valid_dataset: Dataset,
-        eval_dataset: Dataset,
+        train_dataloader: DataLoader,
+        valid_dataloader: DataLoader,
+        eval_dataloader: DataLoader,
         batch_size: int = 1,
         max_epochs: int = 1,
+        callbacks: List[Callback] = [],
     ) -> None:
-        super().__init__()
-        self.strategy = strategy
-        self.epochs = max_epochs
+        super().__init__(strategy, max_epochs, callbacks=callbacks)
         train_sampler = None

-        if dist.is_initialized() and dist.get_world_size() > 1:
-            train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
-        self.train_dataloader = DataLoader(train_dataset,
-                                           shuffle=(train_sampler is None),
-                                           sampler=train_sampler,
-                                           batch_size=batch_size)
-        self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
-        self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
+        self.train_dataloader = train_dataloader
+        self.valid_dataloader = valid_dataloader
+        self.eval_dataloader = eval_dataloader

         self.model = strategy.setup_model(model)
         self.loss_fn = loss_fn
@@ -86,8 +82,8 @@ class RewardModelTrainer(ABC):
     def fit(self):
         time = datetime.now()
-        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
-        for epoch in range(self.epochs):
+        epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
+        for epoch in range(self.max_epochs):
             step_bar = tqdm(range(self.train_dataloader.__len__()),
                             desc='Train step of epoch %d' % epoch,
                             disable=not is_rank_0())

View File

@@ -1,7 +1,6 @@
 import math
 import time
-from abc import ABC
-from typing import Optional
+from typing import Optional, List

 import loralib as lora
 import torch
@@ -19,11 +18,13 @@ from transformers.trainer import get_scheduler

 from colossalai.logging import get_dist_logger

+from .callbacks import Callback
+from .base import Trainer
 from .strategies import Strategy
 from .utils import is_rank_0


-class SFTTrainer(ABC):
+class SFTTrainer(Trainer):
     """
         Trainer to use while training reward model.
@@ -35,6 +36,7 @@ class SFTTrainer(ABC):
         eval_dataloader: the dataloader to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
+        callbacks (List[Callback], defaults to []): the callbacks to call during training process
         optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
     """
@@ -48,10 +50,9 @@ class SFTTrainer(ABC):
         batch_size: int = 1,
         max_epochs: int = 2,
         accimulation_steps: int = 8,
+        callbacks: List[Callback] = [],
     ) -> None:
-        super().__init__()
-        self.strategy = strategy
-        self.epochs = max_epochs
+        super().__init__(strategy, max_epochs, callbacks=callbacks)
         self.train_dataloader = train_dataloader
         self.eval_dataloader = eval_dataloader
@@ -62,7 +63,7 @@ class SFTTrainer(ABC):
         self.accimulation_steps = accimulation_steps
         num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
-        max_steps = math.ceil(self.epochs * num_update_steps_per_epoch)
+        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)

         self.scheduler = get_scheduler("cosine",
                                        self.optimizer,
@@ -74,10 +75,10 @@ class SFTTrainer(ABC):
             wandb.watch(self.model)
         total_loss = 0
         # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
-        step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.epochs),
+        step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.max_epochs),
                         desc=f'steps',
                         disable=not is_rank_0())
-        for epoch in range(self.epochs):
+        for epoch in range(self.max_epochs):
             # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
             # train
@@ -148,7 +149,7 @@ class SFTTrainer(ABC):
                 loss_mean = loss_sum / num_seen
                 if dist.get_rank() == 0:
-                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+                    logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')

             # epoch_bar.update()
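As a quick sanity check of the scheduler arithmetic above (now using `self.max_epochs` instead of `self.epochs`): with a hypothetical 1000-batch dataloader, `accimulation_steps=8` and `max_epochs=2`, there are 125 optimizer updates per epoch, so the cosine schedule is set up for 250 steps:

import math

# Hypothetical sizes, used only to trace the max_steps formula from SFTTrainer above.
num_batches = 1000          # stands in for len(train_dataloader)
accimulation_steps = 8      # spelling kept as in the source
max_epochs = 2

num_update_steps_per_epoch = num_batches // accimulation_steps    # 1000 // 8 = 125
max_steps = math.ceil(max_epochs * num_update_steps_per_epoch)    # ceil(2 * 125) = 250
print(num_update_steps_per_epoch, max_steps)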

View File

@@ -114,8 +114,10 @@ def main(args):
                           eos_token_id=tokenizer.eos_token_id,
                           callbacks=callbacks)

-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 64), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 64), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
@@ -136,7 +138,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
     parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=50)
     parser.add_argument('--max_timesteps', type=int, default=10)

View File

@@ -3,6 +3,7 @@ from random import randint

 import loralib as lora
 import torch
+import torch.distributed as dist
 from coati.dataset import HhRlhfDataset, RmStaticDataset
 from coati.models import LogExpLoss, LogSigLoss
 from coati.models.base import RewardModel
@@ -17,6 +18,8 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrat
 from coati.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@@ -120,13 +123,38 @@ def train(args):
     else:
         raise ValueError(f'Unsupported dataset "{args.dataset}"')

+    if dist.is_initialized() and dist.get_world_size() > 1:
+        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                          num_replicas=dist.get_world_size())
+    else:
+        train_sampler = None
+        valid_sampler = None
+        eval_sampler = None
+
+    train_dataloader = DataLoader(train_dataset,
+                                  shuffle=(train_sampler is None),
+                                  sampler=train_sampler,
+                                  batch_size=args.batch_size,
+                                  pin_memory=True)
+
+    valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
+                                  sampler=valid_sampler,
+                                  batch_size=args.batch_size, pin_memory=True)
+
+    eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
+                                 sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
+
     trainer = RewardModelTrainer(model=model,
                                  strategy=strategy,
                                  optim=optim,
                                  loss_fn=loss_fn,
-                                 train_dataset=train_dataset,
-                                 valid_dataset=valid_dataset,
-                                 eval_dataset=eval_dataset,
+                                 train_dataloader=train_dataloader,
+                                 valid_dataloader=valid_dataloader,
+                                 eval_dataloader=eval_dataloader,
                                  batch_size=args.batch_size,
                                  max_epochs=args.max_epochs)
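The three dataloader blocks above follow the same pattern (a seeded `DistributedSampler` when `torch.distributed` is initialized, plain shuffling otherwise). If you want to avoid the repetition in your own scripts, the logic can be folded into a small helper; this is a hedged sketch of that pattern, not part of the commit:

import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


def build_dataloader(dataset: Dataset, batch_size: int, seed: int = 42) -> DataLoader:
    """Mirror the sampler/dataloader setup used in the training script above."""
    sampler = None
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset,
                                     shuffle=True,
                                     seed=seed,
                                     drop_last=True,
                                     rank=dist.get_rank(),
                                     num_replicas=dist.get_world_size())
    return DataLoader(dataset,
                      shuffle=(sampler is None),    # the sampler already shuffles in the distributed case
                      sampler=sampler,
                      batch_size=batch_size,
                      pin_memory=True)

# Hypothetical usage: train_dataloader = build_dataloader(train_dataset, args.batch_size)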