diff --git a/applications/Chat/README.md b/applications/Chat/README.md
index 29cd581d7..082cbb22b 100644
--- a/applications/Chat/README.md
+++ b/applications/Chat/README.md
@@ -83,7 +83,7 @@ More details can be found in the latest news.

-> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --max_timesteps 1 --update_timesteps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32 +> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --num_collect_steps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32 ## Install diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index dea7ebc60..39f2f28ec 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -137,6 +137,12 @@ def main(args): (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) + random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) + dataloader = DataLoader(random_prompts, + batch_size=args.experience_batch_size, + shuffle=True, + collate_fn=preprocess_batch) + trainer = PPOTrainer(strategy, actor, critic, @@ -145,7 +151,6 @@ def main(args): actor_optim, critic_optim, ptx_coef=0, - max_epochs=args.max_epochs, train_batch_size=args.train_batch_size, offload_inference_models=args.offload_inference_models, max_length=512, @@ -157,17 +162,11 @@ def main(args): eos_token_id=tokenizer.eos_token_id, callbacks=[performance_evaluator]) - random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) - dataloader = DataLoader(random_prompts, - batch_size=args.experience_batch_size, - shuffle=True, - collate_fn=preprocess_batch) - - trainer.fit(dataloader, - None, + trainer.fit(prompt_dataloader=dataloader, + pretrain_dataloader=None, num_episodes=args.num_episodes, - max_timesteps=args.max_timesteps, - update_timesteps=args.update_timesteps) + num_update_steps=args.num_update_steps, + num_collect_steps=args.num_collect_steps) print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') @@ -183,9 +182,8 @@ if __name__ == '__main__': ], default='ddp') parser.add_argument('--num_episodes', type=int, default=3) - parser.add_argument('--max_timesteps', type=int, default=8) - parser.add_argument('--update_timesteps', type=int, default=8) - parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--num_collect_steps', type=int, default=8) + parser.add_argument('--num_update_steps', type=int, default=1) parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument('--lora_rank', type=int, default=0) diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py index 525b57bf2..86142361f 100644 --- a/applications/Chat/coati/trainer/__init__.py +++ b/applications/Chat/coati/trainer/__init__.py @@ -1,6 +1,10 @@ -from .base import Trainer +from .base import OnPolicyTrainer, SLTrainer from .ppo import PPOTrainer from .rm import RewardModelTrainer from .sft import SFTTrainer -__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer'] 
+__all__ = [ + 'SLTrainer', 'OnPolicyTrainer', + 'RewardModelTrainer', 'SFTTrainer', + 'PPOTrainer' +] diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py index ac3a878be..13571cdcc 100644 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/Chat/coati/trainer/base.py @@ -1,54 +1,108 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union +from contextlib import contextmanager +from typing import List -import torch +import torch.nn as nn +import tqdm from coati.experience_maker import Experience +from coati.replay_buffer import NaiveReplayBuffer +from torch.optim import Optimizer +from torch.utils.data import DataLoader from .callbacks import Callback from .strategies import Strategy +from .utils import CycledDataLoader, is_rank_0 -class Trainer(ABC): +class SLTrainer(ABC): """ - Base class for rlhf trainers. + Base class for supervised learning trainers. Args: strategy (Strategy):the strategy to use for training max_epochs (int, defaults to 1): the number of epochs of training process - dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader - callbacks (List[Callback], defaults to []): the callbacks to call during training process - generate_kwargs (dict, optional): the kwargs to use while model generating + model (nn.Module): the model to train + optim (Optimizer): the optimizer to use for training """ def __init__(self, strategy: Strategy, - max_epochs: int = 1, - dataloader_pin_memory: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs) -> None: + max_epochs: int, + model: nn.Module, + optimizer: Optimizer, + ) -> None: super().__init__() self.strategy = strategy self.max_epochs = max_epochs - self.generate_kwargs = generate_kwargs + self.model = model + self.optimizer = optimizer + + @abstractmethod + def _train(self, epoch): + raise NotImplementedError() + + @abstractmethod + def _eval(self, epoch): + raise NotImplementedError() + + def _before_fit(self): + self.no_epoch_bar = False + + def fit(self, *args, **kwargs): + self._before_fit(*args, **kwargs) + for epoch in tqdm.trange(self.max_epochs, + desc="Epochs", + disable=not is_rank_0() or self.no_epoch_bar + ): + self._train(epoch) + self._eval(epoch) + + +class OnPolicyTrainer(ABC): + """ + Base class for on-policy rl trainers, e.g. PPO. 
+ + Args: + strategy (Strategy):the strategy to use for training + buffer (NaiveReplayBuffer): the buffer to collect experiences + sample_buffer (bool, defaults to False): whether to sample from buffer + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + """ + + def __init__(self, + strategy: Strategy, + buffer: NaiveReplayBuffer, + sample_buffer: bool, + dataloader_pin_memory: bool, + callbacks: List[Callback] = [] + ) -> None: + super().__init__() + self.strategy = strategy + self.buffer = buffer + self.sample_buffer = sample_buffer self.dataloader_pin_memory = dataloader_pin_memory self.callbacks = callbacks - # TODO(ver217): maybe simplify these code using context - def _on_fit_start(self) -> None: + @contextmanager + def _fit_ctx(self) -> None: for callback in self.callbacks: callback.on_fit_start() + try: + yield + finally: + for callback in self.callbacks: + callback.on_fit_end() - def _on_fit_end(self) -> None: - for callback in self.callbacks: - callback.on_fit_end() - - def _on_episode_start(self, episode: int) -> None: + @contextmanager + def _episode_ctx(self, episode: int) -> None: for callback in self.callbacks: callback.on_episode_start(episode) - - def _on_episode_end(self, episode: int) -> None: - for callback in self.callbacks: - callback.on_episode_end(episode) + try: + yield + finally: + for callback in self.callbacks: + callback.on_episode_end(episode) def _on_make_experience_start(self) -> None: for callback in self.callbacks: @@ -73,3 +127,71 @@ class Trainer(ABC): def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: for callback in self.callbacks: callback.on_learn_batch_end(metrics, experience) + + @abstractmethod + def _make_experience(self, collect_step: int): + """ + Implement this method to make experience. + """ + raise NotImplementedError() + + @abstractmethod + def _learn(self, update_step: int): + """ + Implement this method to learn from experience, either + sample from buffer or transform buffer into dataloader. + """ + raise NotImplementedError() + + def _collect_phase(self, collect_step: int): + self._on_make_experience_start() + experience = self._make_experience(collect_step) + self._on_make_experience_end(experience) + self.buffer.append(experience) + + def _update_phase(self, update_step: int): + self._on_learn_epoch_start(update_step) + self._learn(update_step) + self._on_learn_epoch_end(update_step) + + def fit(self, + prompt_dataloader: DataLoader, + pretrain_dataloader: DataLoader, + num_episodes: int, + num_collect_steps: int, + num_update_steps: int, + ): + """ + The main training loop of on-policy rl trainers. 
+ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader (DataLoader): the dataloader to use for pretrain data + num_episodes (int): the number of episodes to train + num_collect_steps (int): the number of collect steps per episode + num_update_steps (int): the number of update steps per episode + """ + self.prompt_dataloader = CycledDataLoader(prompt_dataloader) + self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) + + with self._fit_ctx(): + for episode in tqdm.trange(num_episodes, + desc="Episodes", + disable=not is_rank_0()): + with self._episode_ctx(episode): + for collect_step in tqdm.trange(num_collect_steps, + desc="Collect steps", + disable=not is_rank_0()): + self._collect_phase(collect_step) + if not self.sample_buffer: + # HACK(cwher): according to the design of boost API, dataloader should also be boosted, + # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. + # I only call strategy.setup_dataloader() to setup dataloader. + self.dataloader = self.strategy.setup_dataloader(self.buffer, + self.dataloader_pin_memory) + for update_step in tqdm.trange(num_update_steps, + desc="Update steps", + disable=not is_rank_0()): + self._update_phase(update_step) + # NOTE: this is for on-policy algorithms + self.buffer.clear() diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index cfb18e2ae..451abe2a7 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -1,6 +1,5 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Dict, List -import torch import torch.nn as nn from coati.experience_maker import Experience, NaiveExperienceMaker from coati.models.base import Actor, Critic, get_base_model @@ -9,19 +8,32 @@ from coati.models.utils import calc_action_log_probs from coati.replay_buffer import NaiveReplayBuffer from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DistributedSampler +from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm -from transformers.tokenization_utils_base import PreTrainedTokenizerBase from colossalai.utils import get_current_device -from .base import Trainer +from .base import OnPolicyTrainer from .callbacks import Callback from .strategies import ColossalAIStrategy, Strategy from .utils import is_rank_0, to_device -class PPOTrainer(Trainer): +def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict: + unwrapper_model = strategy.unwrap_model(actor) + hf_model = get_base_model(unwrapper_model) + new_kwargs = {**generate_kwargs} + # use huggingface models method directly + if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'): + new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation + + if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'): + new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation + + return new_kwargs + + +class PPOTrainer(OnPolicyTrainer): """ Trainer for PPO algorithm. 
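For orientation: the `OnPolicyTrainer.fit` loop above replaces the old `time`/`max_timesteps`/`update_timesteps` bookkeeping with an explicit collect phase and update phase per episode. The following is a self-contained toy sketch of that schedule only; every name in it is an illustrative stand-in rather than the coati API, and callbacks, the `sample_buffer` path, and distributed details are omitted.

```python
# Toy sketch of the collect/update schedule driven by OnPolicyTrainer.fit
# (stand-in names, not the coati API).
from itertools import cycle

num_episodes, num_collect_steps, num_update_steps = 2, 3, 1

prompt_batches = cycle([f"prompt_batch_{i}" for i in range(4)])  # stands in for CycledDataLoader
buffer = []                                                      # stands in for NaiveReplayBuffer

def make_experience(prompts):
    # stands in for experience_maker.make_experience (rollout with the current policy)
    return f"experience({prompts})"

def training_step(experience):
    # stands in for the PPO actor/critic update on one experience batch
    print("  update on", experience)

for episode in range(num_episodes):
    for _ in range(num_collect_steps):          # collect phase
        buffer.append(make_experience(next(prompt_batches)))
    for _ in range(num_update_steps):           # update phase: one pass over the buffer
        for experience in buffer:
            training_step(experience)
    buffer.clear()                              # on-policy: stale experience is dropped
    print(f"episode {episode} finished")
```

With the new benchmark defaults (`--num_collect_steps 8`, `--num_update_steps 1`), each episode therefore gathers eight experience batches and then makes a single pass over the resulting buffer dataloader.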
@@ -35,14 +47,13 @@ class PPOTrainer(Trainer): critic_optim (Optimizer): the optimizer to use for critic model kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss train_batch_size (int, defaults to 8): the batch size to use for training - buffer_limit (int, defaults to 0): the max_size limitation of replay buffer - buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + buffer_limit (int, defaults to 0): the max_size limitation of buffer + buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu eps_clip (float, defaults to 0.2): the clip coefficient of policy loss vf_coef (float, defaults to 1.0): the coefficient of value loss ptx_coef (float, defaults to 0.9): the coefficient of ptx loss value_clip (float, defaults to 0.4): the clip coefficient of value loss - max_epochs (int, defaults to 1): the number of epochs of training process - sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer + sample_buffer (bool, defaults to False): whether to sample from buffer dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process callbacks (List[Callback], defaults to []): the callbacks to call during training process @@ -65,25 +76,26 @@ class PPOTrainer(Trainer): eps_clip: float = 0.2, vf_coef: float = 1.0, value_clip: float = 0.4, - max_epochs: int = 1, - sample_replay_buffer: bool = False, + sample_buffer: bool = False, dataloader_pin_memory: bool = True, offload_inference_models: bool = True, callbacks: List[Callback] = [], - **generate_kwargs) -> None: + **generate_kwargs + ) -> None: if isinstance(strategy, ColossalAIStrategy): from colossalai.booster.plugin import GeminiPlugin assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \ "GeminiPlugin is not compatible with manual model.to('cpu')" - experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) - replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) - generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) - super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs) + buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + super().__init__( + strategy, buffer, + sample_buffer, dataloader_pin_memory, + callbacks + ) - self.experience_maker = experience_maker - self.replay_buffer = replay_buffer - self.sample_replay_buffer = sample_replay_buffer + self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) + self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) self.offload_inference_models = offload_inference_models self.actor = actor @@ -99,76 +111,20 @@ class PPOTrainer(Trainer): self.device = get_current_device() - def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: - if isinstance(inputs, Tensor): - return self.experience_maker.make_experience(inputs, **self.generate_kwargs) - elif isinstance(inputs, dict): - return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) + def _make_experience(self, collect_step: int) -> Experience: + prompts = self.prompt_dataloader.next() + if self.offload_inference_models: + # TODO(ver217): this may be controlled by strategy if they are prepared by strategy + 
self.experience_maker.initial_model.to(self.device) + self.experience_maker.reward_model.to(self.device) + if isinstance(prompts, Tensor): + return self.experience_maker.make_experience(prompts, **self.generate_kwargs) + elif isinstance(prompts, dict): + return self.experience_maker.make_experience(**prompts, **self.generate_kwargs) else: - raise ValueError(f'Unsupported input type "{type(inputs)}"') + raise ValueError(f'Unsupported input type "{type(prompts)}"') - def _learn(self): - # replay buffer may be empty at first, we should rebuild at each training - if not self.sample_replay_buffer: - # HACK(cwher): according to the design of boost API, dataloader should also be boosted, - # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. - dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory) - if self.sample_replay_buffer: - pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) - for _ in pbar: - experience = self.replay_buffer.sample() - experience.to_device(self.device) - metrics = self.training_step(experience) - pbar.set_postfix(metrics) - else: - for epoch in range(self.max_epochs): - self._on_learn_epoch_start(epoch) - if isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(epoch) - pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0()) - for experience in pbar: - self._on_learn_batch_start() - experience.to_device(self.device) - metrics = self.training_step(experience) - self._on_learn_batch_end(metrics, experience) - pbar.set_postfix(metrics) - self._on_learn_epoch_end(epoch) - - def fit(self, - prompt_dataloader, - pretrain_dataloader, - num_episodes: int = 50000, - max_timesteps: int = 500, - update_timesteps: int = 5000) -> None: - time = 0 - self.pretrain_dataloader = pretrain_dataloader - self.prompt_dataloader = prompt_dataloader - self._on_fit_start() - for episode in range(num_episodes): - self._on_episode_start(episode) - for timestep in tqdm(range(max_timesteps), - desc=f'Episode [{episode+1}/{num_episodes}]', - disable=not is_rank_0()): - time += 1 - prompts = next(iter(self.prompt_dataloader)) - self._on_make_experience_start() - if self.offload_inference_models: - # TODO(ver217): this may be controlled by strategy if they are prepared by strategy - self.experience_maker.initial_model.to(self.device) - self.experience_maker.reward_model.to(self.device) - experience = self._make_experience(prompts) - self._on_make_experience_end(experience) - self.replay_buffer.append(experience) - if time % update_timesteps == 0: - if self.offload_inference_models: - self.experience_maker.initial_model.to('cpu') - self.experience_maker.reward_model.to('cpu') - self._learn() - self.replay_buffer.clear() - self._on_episode_end(episode) - self._on_fit_end() - - def training_step(self, experience: Experience) -> Dict[str, float]: + def _training_step(self, experience: Experience) -> Dict[str, float]: self.actor.train() self.critic.train() # policy loss @@ -182,7 +138,7 @@ class PPOTrainer(Trainer): # ptx loss if self.ptx_coef != 0: - batch = next(iter(self.pretrain_dataloader)) + batch = self.pretrain_dataloader.next() batch = to_device(batch, self.device) ptx_log_probs = self.actor(batch['input_ids'], attention_mask=batch['attention_mask'])['logits'] @@ -208,16 +164,29 @@ class PPOTrainer(Trainer): return {'reward': experience.reward.mean().item()} + def _learn(self, update_step: int): + if 
self.offload_inference_models: + self.experience_maker.initial_model.to('cpu') + self.experience_maker.reward_model.to('cpu') -def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict: - unwrapper_model = strategy.unwrap_model(actor) - hf_model = get_base_model(unwrapper_model) - new_kwargs = {**generate_kwargs} - # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation - - if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'): - new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation - - return new_kwargs + # buffer may be empty at first, we should rebuild at each training + if self.sample_buffer: + experience = self.buffer.sample() + self._on_learn_batch_start() + experience.to_device(self.device) + metrics = self._training_step(experience) + self._on_learn_batch_end(metrics, experience) + else: + if isinstance(self.dataloader.sampler, DistributedSampler): + self.dataloader.sampler.set_epoch(update_step) + pbar = tqdm( + self.dataloader, + desc=f'Train epoch [{update_step + 1}]', + disable=not is_rank_0() + ) + for experience in pbar: + self._on_learn_batch_start() + experience.to_device(self.device) + metrics = self._training_step(experience) + self._on_learn_batch_end(metrics, experience) + pbar.set_postfix(metrics) diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py index 316eded7e..54a5d0f40 100644 --- a/applications/Chat/coati/trainer/rm.py +++ b/applications/Chat/coati/trainer/rm.py @@ -1,20 +1,19 @@ from datetime import datetime -from typing import Callable, List +from typing import Callable import pandas as pd import torch +import tqdm from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from tqdm import tqdm -from .base import Trainer -from .callbacks import Callback +from .base import SLTrainer from .strategies import Strategy from .utils import is_rank_0 -class RewardModelTrainer(Trainer): +class RewardModelTrainer(SLTrainer): """ Trainer to use while training reward model. 
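The reward-model hunks that follow (and the SFT hunks after them) move their training loops into the `SLTrainer` hooks introduced in `base.py` earlier in this patch: dataloaders are now handed to `fit()`, which forwards them to `_before_fit()` and then alternates `_train(epoch)` and `_eval(epoch)`. Below is a dependency-free mirror of that skeleton; `MiniSLTrainer` and its toy data are hypothetical and not part of the patch, and the tqdm progress bar plus `is_rank_0()` gating are omitted.

```python
# Hypothetical mirror of the SLTrainer skeleton from coati/trainer/base.py.
class MiniSLTrainer:
    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    def _before_fit(self, train_data, eval_data=None):
        # dataloaders now arrive here via fit(), instead of through __init__
        self.train_data = train_data
        self.eval_data = eval_data

    def _train(self, epoch: int):
        print(f"epoch {epoch}: {len(self.train_data)} training batches")

    def _eval(self, epoch: int):
        if self.eval_data is not None:
            print(f"epoch {epoch}: {len(self.eval_data)} eval batches")

    def fit(self, *args, **kwargs):
        self._before_fit(*args, **kwargs)
        for epoch in range(self.max_epochs):
            self._train(epoch)
            self._eval(epoch)

MiniSLTrainer(max_epochs=2).fit(train_data=[1, 2, 3], eval_data=[4])
```

This is also why `train_reward_model.py` and `train_sft.py` further below now pass their dataloaders to `trainer.fit(...)` rather than to the trainer constructor.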
@@ -24,12 +23,7 @@ class RewardModelTrainer(Trainer): optim (Optimizer): the optimizer to use for training lr_scheduler (_LRScheduler): the lr scheduler to use for training loss_fn (callable): the loss function to use for training - train_dataloader (DataLoader): the dataloader to use for training - valid_dataloader (DataLoader): the dataloader to use for validation - eval_dataloader (DataLoader): the dataloader to use for evaluation - batch_size (int, defaults to 1): the batch size while training max_epochs (int, defaults to 2): the number of epochs to train - callbacks (List[Callback], defaults to []): the callbacks to call during training process """ def __init__( @@ -39,87 +33,79 @@ class RewardModelTrainer(Trainer): optim: Optimizer, lr_scheduler: _LRScheduler, loss_fn: Callable, - train_dataloader: DataLoader, - valid_dataloader: DataLoader, - eval_dataloader: DataLoader, max_epochs: int = 1, - callbacks: List[Callback] = [], ) -> None: - super().__init__(strategy, max_epochs, callbacks=callbacks) + super().__init__(strategy, max_epochs, model, optim) + + self.loss_fn = loss_fn + self.scheduler = lr_scheduler + + def _eval(self, epoch): + if self.eval_dataloader is not None: + self.model.eval() + dist, on, cnt = 0, 0, 0 + with torch.no_grad(): + for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader: + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + for i in range(len(chosen_reward)): + cnt += 1 + if chosen_reward[i] > reject_reward[i]: + on += 1 + dist += (chosen_reward - reject_reward).mean().item() + self.dist = dist / len(self.eval_dataloader) + self.acc = on / cnt + + if is_rank_0(): + log = pd.DataFrame( + [[(epoch + 1) * len(self.train_dataloader), + self.loss.item(), self.dist, self.acc]], + columns=['step', 'loss', 'dist', 'acc'] + ) + log.to_csv('log.csv', mode='a', header=False, index=False) + + def _train(self, epoch): + self.model.train() + step_bar = tqdm.trange( + len(self.train_dataloader), + desc='Train step of epoch %d' % epoch, + disable=not is_rank_0() + ) + cnt = 0 + for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + self.loss = self.loss_fn(chosen_reward, reject_reward) + self.strategy.backward(self.loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer) + self.optimizer.zero_grad() + cnt += 1 + if cnt % 100 == 0: + self.scheduler.step() + step_bar.update() + step_bar.close() + + def _before_fit(self, + train_dataloader: DataLoader, + valid_dataloader: DataLoader, + eval_dataloader: DataLoader): + """ + Args: + train_dataloader (DataLoader): the dataloader to use for training + valid_dataloader (DataLoader): the dataloader to use for validation + eval_dataloader (DataLoader): the dataloader to use for evaluation + """ + super()._before_fit() + self.datetime = 
datetime.now().strftime('%Y-%m-%d_%H-%M-%S') self.train_dataloader = train_dataloader self.valid_dataloader = valid_dataloader self.eval_dataloader = eval_dataloader - - self.model = model - self.loss_fn = loss_fn - self.optimizer = optim - self.scheduler = lr_scheduler - - def eval_acc(self, dataloader): - dist = 0 - on = 0 - cnt = 0 - self.model.eval() - with torch.no_grad(): - for chosen_ids, c_mask, reject_ids, r_mask in dataloader: - chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) - c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) - reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) - r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) - chosen_reward = self.model(chosen_ids, attention_mask=c_mask) - reject_reward = self.model(reject_ids, attention_mask=r_mask) - for i in range(len(chosen_reward)): - cnt += 1 - if chosen_reward[i] > reject_reward[i]: - on += 1 - dist += (chosen_reward - reject_reward).mean().item() - dist_mean = dist / len(dataloader) - acc = on / cnt - self.model.train() - return dist_mean, acc - - def fit(self): - time = datetime.now() - epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) - for epoch in range(self.max_epochs): - step_bar = tqdm(range(self.train_dataloader.__len__()), - desc='Train step of epoch %d' % epoch, - disable=not is_rank_0()) - # train - self.model.train() - cnt = 0 - acc = 0 - dist = 0 - for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: - chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) - c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) - reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) - r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) - chosen_reward = self.model(chosen_ids, attention_mask=c_mask) - reject_reward = self.model(reject_ids, attention_mask=r_mask) - loss = self.loss_fn(chosen_reward, reject_reward) - self.strategy.backward(loss, self.model, self.optimizer) - self.strategy.optimizer_step(self.optimizer) - self.optimizer.zero_grad() - cnt += 1 - if cnt == 100: - self.scheduler.step() - dist, acc = self.eval_acc(self.valid_dataloader) - cnt = 0 - if is_rank_0(): - log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], - columns=['step', 'loss', 'dist', 'acc']) - log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False) - step_bar.update() - step_bar.set_postfix({'dist': dist, 'acc': acc}) - - # eval - dist, acc = self.eval_acc(self.eval_dataloader) - if is_rank_0(): - log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], - columns=['step', 'loss', 'dist', 'acc']) - log.to_csv('log.csv', mode='a', header=False, index=False) - epoch_bar.update() - step_bar.set_postfix({'dist': dist, 'acc': acc}) - step_bar.close() diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py index da223f1f3..12c51d7a8 100644 --- a/applications/Chat/coati/trainer/sft.py +++ b/applications/Chat/coati/trainer/sft.py @@ -1,21 +1,22 @@ import time -from typing import List +from typing import Optional import torch import torch.distributed as dist +import tqdm import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from tqdm import tqdm -from .base import Trainer -from .callbacks import Callback +from colossalai.logging import DistributedLogger + +from .base import SLTrainer from .strategies import ColossalAIStrategy, Strategy from .utils import is_rank_0, to_device 
-class SFTTrainer(Trainer): +class SFTTrainer(SLTrainer): """ Trainer to use while training reward model. @@ -23,12 +24,9 @@ class SFTTrainer(Trainer): model (torch.nn.Module): the model to train strategy (Strategy): the strategy to use for training optim(Optimizer): the optimizer to use for training - train_dataloader: the dataloader to use for training - eval_dataloader: the dataloader to use for evaluation - batch_size (int, defaults to 1): the batch size while training + lr_scheduler(_LRScheduler): the lr scheduler to use for training max_epochs (int, defaults to 2): the number of epochs to train - callbacks (List[Callback], defaults to []): the callbacks to call during training process - optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer + accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients """ def __init__( @@ -37,95 +35,92 @@ class SFTTrainer(Trainer): strategy: Strategy, optim: Optimizer, lr_scheduler: _LRScheduler, - train_dataloader: DataLoader, - eval_dataloader: DataLoader = None, max_epochs: int = 2, accumulation_steps: int = 8, - callbacks: List[Callback] = [], ) -> None: if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy): from colossalai.booster.plugin import GeminiPlugin assert not isinstance(strategy.plugin, GeminiPlugin), \ "Accumulation steps are not supported in stage 3 of ColossalAI" - super().__init__(strategy, max_epochs, callbacks=callbacks) - self.train_dataloader = train_dataloader - self.eval_dataloader = eval_dataloader - self.model = model - self.optimizer = optim + + super().__init__(strategy, max_epochs, model, optim) self.accumulation_steps = accumulation_steps - self.scheduler = lr_scheduler - def fit(self, logger, use_wandb: bool = False): + def _train(self, epoch: int): + self.model.train() + for batch_id, batch in enumerate(self.train_dataloader): + + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model(batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"]) + + loss = outputs.loss + loss = loss / self.accumulation_steps + + self.strategy.backward(loss, self.model, self.optimizer) + + self.total_loss += loss.item() + + # gradient accumulation + if (batch_id + 1) % self.accumulation_steps == 0: + self.strategy.optimizer_step(self.optimizer) + self.optimizer.zero_grad() + self.scheduler.step() + if is_rank_0() and self.use_wandb: + wandb.log({ + "loss": self.total_loss / self.accumulation_steps, + "lr": self.scheduler.get_last_lr()[0], + "epoch": epoch, + "batch_id": batch_id + }) + self.total_loss = 0 + self.step_bar.update() + + def _eval(self, epoch: int): + if self.eval_dataloader is not None: + self.model.eval() + with torch.no_grad(): + loss_sum, num_seen = 0, 0 + for batch in self.eval_dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model(batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"]) + loss = outputs.loss + + loss_sum += loss.item() + num_seen += batch["input_ids"].size(0) + + loss_mean = loss_sum / num_seen + if dist.get_rank() == 0: + self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}') + + def _before_fit(self, + train_dataloader: DataLoader, + eval_dataloader: Optional[DataLoader] = None, + logger: Optional[DistributedLogger] = None, + use_wandb: bool = False): + """ + Args: + train_dataloader: the dataloader to use for training + eval_dataloader: the dataloader to use for evaluation + """ + 
self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + + self.logger = logger + self.use_wandb = use_wandb if use_wandb: wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) wandb.watch(self.model) - total_loss = 0 - # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0()) - step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs), - desc=f'steps', - disable=not is_rank_0()) - for epoch in range(self.max_epochs): - # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0()) - # train - self.model.train() - for batch_id, batch in enumerate(self.train_dataloader): - - batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) - - loss = outputs.loss - - if loss >= 2.5 and is_rank_0(): - logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}") - - loss = loss / self.accumulation_steps - - self.strategy.backward(loss, self.model, self.optimizer) - - total_loss += loss.item() - - # gradient accumulation - if (batch_id + 1) % self.accumulation_steps == 0: - self.strategy.optimizer_step(self.optimizer) - self.optimizer.zero_grad() - self.scheduler.step() - if is_rank_0() and use_wandb: - wandb.log({ - "loss": total_loss / self.accumulation_steps, - "lr": self.scheduler.get_last_lr()[0], - "epoch": epoch, - "batch_id": batch_id - }) - total_loss = 0 - step_bar.update() - - # if batch_id % log_interval == 0: - # logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}') - # wandb.log({"loss": loss.item()}) - - # process_bar.update() - - # eval - if self.eval_dataloader is not None: - self.model.eval() - with torch.no_grad(): - loss_sum = 0 - num_seen = 0 - for batch in self.eval_dataloader: - batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"]) - loss = outputs.loss - - loss_sum += loss.item() - num_seen += batch["input_ids"].size(0) - - loss_mean = loss_sum / num_seen - if dist.get_rank() == 0: - logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}') - - # epoch_bar.update() + self.total_loss = 0 + self.no_epoch_bar = True + self.step_bar = tqdm.trange( + len(self.train_dataloader) // self.accumulation_steps * self.max_epochs, + desc=f'steps', + disable=not is_rank_0() + ) diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index f31551f22..e5a69f335 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -1,4 +1,3 @@ -import functools import warnings from typing import Optional @@ -103,7 +102,7 @@ class ColossalAIStrategy(DDPStrategy): # NOTE: dist should be initialized before calling get_current_device() if stage == 3: plugin_initializer = lambda: GeminiPlugin( - # gemini_config + # gemini_config device=get_current_device(), placement_policy=placement_policy, precision=precision, @@ -113,20 +112,20 @@ class ColossalAIStrategy(DDPStrategy): search_range_m=search_range_m, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m, - # zero_optim_config + # zero_optim_config gpu_margin_mem_ratio=gpu_margin_mem_ratio, - # optim_config + # optim_config **optim_kwargs) else: plugin_initializer = 
lambda: LowLevelZeroPlugin( - # zero_config + # zero_config stage=stage, precision=precision, - # zero_optim_config + # zero_optim_config reduce_bucket_size_in_m=reduce_bucket_size, overlap_communication=overlap_communication, cpu_offload=(placement_policy == 'cpu'), - # optim_config + # optim_config **optim_kwargs) super().__init__(seed, plugin_initializer) diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py index 9cccb5c92..c9fc8d0fe 100644 --- a/applications/Chat/coati/trainer/utils.py +++ b/applications/Chat/coati/trainer/utils.py @@ -3,6 +3,33 @@ from typing import Any import torch import torch.distributed as dist from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader + + +class CycledDataLoader: + """ + Why do we need this class? + In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain. + However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...) + NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior. + """ + + def __init__(self, + dataloader: DataLoader, + ) -> None: + self.dataloader = dataloader + + self.count = 0 + self.dataloader_iter = iter(dataloader) + + def next(self): + self.count += 1 + try: + return next(self.dataloader_iter) + except StopIteration: + self.count = 0 + self.dataloader_iter = iter(self.dataloader) + return next(self.dataloader_iter) def is_rank_0() -> bool: diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md index 72810738d..3e9d9c432 100644 --- a/applications/Chat/examples/README.md +++ b/applications/Chat/examples/README.md @@ -171,9 +171,8 @@ Pretrain dataset: the pretrain dataset including the instruction and correspondi - --pretrain_dataset: path of the ptx dataset, type=str, default=None - --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False - --num_episodes: num of episodes for training, type=int, default=10 -- --max_epochs: max epochs for training in one episode, type=int, default=5 -- --max_timesteps: max episodes in one batch, type=int, default=10 -- --update_timesteps: timesteps to update, type=int, default=10 +- --num_update_steps: number of steps to update policy per episode, type=int +- --num_collect_steps: number of steps to collect experience per episode, type=int - --train_batch_size: batch size while training, type=int, default=8 - --ptx_batch_size: batch size to compute ptx loss, type=int, default=1 - --experience_batch_size: batch size to make experience, type=int, default=8 diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py index ba8470f38..00ed7aa36 100644 --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -171,7 +171,6 @@ def main(args): critic_optim, kl_coef=args.kl_coef, ptx_coef=args.ptx_coef, - max_epochs=args.max_epochs, train_batch_size=args.train_batch_size, experience_batch_size=args.experience_batch_size, tokenizer=tokenize_fn, @@ -186,8 +185,8 @@ def main(args): trainer.fit(prompt_dataloader=prompt_dataloader, pretrain_dataloader=pretrain_dataloader, num_episodes=args.num_episodes, - max_timesteps=args.max_timesteps, - update_timesteps=args.update_timesteps) + num_update_steps=args.num_update_steps, + 
num_collect_steps=args.num_collect_steps) # save model checkpoint after fitting trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) @@ -215,9 +214,8 @@ if __name__ == '__main__': parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--num_collect_steps', type=int, default=10) + parser.add_argument('--num_update_steps', type=int, default=5) parser.add_argument('--train_batch_size', type=int, default=2) parser.add_argument('--ptx_batch_size', type=int, default=1) parser.add_argument('--experience_batch_size', type=int, default=8) diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh index 85728e958..4bf5524af 100755 --- a/applications/Chat/examples/test_ci.sh +++ b/applications/Chat/examples/test_ci.sh @@ -63,8 +63,8 @@ for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ --strategy $strategy --model $model \ - --num_episodes 1 --max_timesteps 2 \ - --update_timesteps 2 --max_epochs 1 --train_batch_size 2 + --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \ + --train_batch_size 2 done done @@ -149,8 +149,8 @@ rm -rf ${BASE}/rm_ckpt.pt torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ - --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ + --strategy colossalai_zero2 --num_episodes 1 \ + --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \ --pretrain 'facebook/opt-350m' --model opt \ --rm_pretrain 'facebook/opt-350m' \ --rm_path ${BASE}/rm_ckpt_opt.pt \ @@ -159,8 +159,8 @@ rm -rf ${BASE}/rm_ckpt_opt.pt torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ - --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ + --strategy colossalai_zero2 --num_episodes 1 \ + --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \ --pretrain 'gpt2' --model gpt2 \ --rm_pretrain 'gpt2' \ --rm_path ${BASE}/rm_ckpt_gpt.pt \ @@ -168,8 +168,8 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \ - --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ + --strategy colossalai_gemini --num_episodes 1 \ + --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \ --pretrain 'gpt2' --model gpt2 \ --rm_pretrain 'gpt2' \ --rm_path ${BASE}/rm_ckpt_gpt.pt \ diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index 2a47dda63..a9bc0e532 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -177,7 +177,6 @@ def main(args): critic_optim, kl_coef=args.kl_coef, ptx_coef=args.ptx_coef, - 
max_epochs=args.max_epochs, train_batch_size=args.train_batch_size, max_length=args.max_seq_len, use_cache=True, @@ -192,8 +191,8 @@ def main(args): trainer.fit(prompt_dataloader=prompt_dataloader, pretrain_dataloader=pretrain_dataloader, num_episodes=args.num_episodes, - max_timesteps=args.max_timesteps, - update_timesteps=args.update_timesteps) + num_collect_steps=args.num_collect_steps, + num_update_steps=args.num_update_steps) # save model checkpoint after fitting strategy.save_model(actor, args.save_path, only_rank0=True) @@ -220,9 +219,8 @@ if __name__ == '__main__': parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--num_collect_steps', type=int, default=10) + parser.add_argument('--num_update_steps', type=int, default=5) parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument('--ptx_batch_size', type=int, default=1) parser.add_argument('--experience_batch_size', type=int, default=8) diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh index 7f3b2636c..d04c41601 100755 --- a/applications/Chat/examples/train_prompts.sh +++ b/applications/Chat/examples/train_prompts.sh @@ -1,13 +1,13 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { local n=${1:-"9999"} echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') echo "Now CUDA_VISIBLE_DEVICES is set to:" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" @@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2 # torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2 -torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2 +torchrun --standalone --nproc_per_node=2 train_prompts.py \ + --pretrain_dataset /path/to/data.json \ + --prompt_dataset /path/to/data.json \ + --strategy colossalai_zero2 \ + --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \ + --train_batch_size 2 diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index 2df3bc391..4a6851ab5 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -178,12 +178,11 @@ def train(args): optim=optim, lr_scheduler=lr_scheduler, loss_fn=loss_fn, - train_dataloader=train_dataloader, - valid_dataloader=valid_dataloader, - eval_dataloader=eval_dataloader, max_epochs=args.max_epochs) - trainer.fit() + trainer.fit(train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + eval_dataloader=eval_dataloader) # save model checkpoint after fitting on only rank0 strategy.save_model(model, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks diff --git a/applications/Chat/examples/train_sft.py 
b/applications/Chat/examples/train_sft.py index 717eb9531..967b7c277 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -170,12 +170,13 @@ def train(args): strategy=strategy, optim=optim, lr_scheduler=lr_scheduler, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, max_epochs=args.max_epochs, accumulation_steps=args.accumulation_steps) - trainer.fit(logger=logger, use_wandb=args.use_wandb) + trainer.fit(train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + logger=logger, + use_wandb=args.use_wandb) # save model checkpoint after fitting on only rank0 strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
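Finally, the `CycledDataLoader` added to `coati/trainer/utils.py` above can be exercised on its own; its point is that `next()` keeps returning batches indefinitely by restarting the underlying iterator instead of re-creating it (and its workers) on every call. A minimal usage sketch, assuming the patched `coati` package and `torch` are installed:

```python
# Usage sketch of CycledDataLoader; the toy integer dataset is illustrative only.
from coati.trainer.utils import CycledDataLoader
from torch.utils.data import DataLoader

loader = CycledDataLoader(DataLoader(list(range(5)), batch_size=2))
for _ in range(6):
    print(loader.next())   # wraps back to the first batch after the 3rd call
```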