diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py b/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py
index 79ea9ffcd..9ed0ee6f7 100644
--- a/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py
+++ b/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py
@@ -1,4 +1,5 @@
 from .base import Callback
 from .performance_evaluator import PerformanceEvaluator
+from .save_checkpoint import SaveCheckpoint
 
-__all__ = ['Callback', 'PerformanceEvaluator']
+__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/save_checkpoint.py b/applications/ChatGPT/chatgpt/trainer/callbacks/save_checkpoint.py
new file mode 100644
index 000000000..8f2beb12d
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/callbacks/save_checkpoint.py
@@ -0,0 +1,75 @@
+import os
+
+import torch.distributed as dist
+from chatgpt.trainer.strategies import ColossalAIStrategy, Strategy
+from chatgpt.trainer.utils import is_rank_0
+from torch import nn
+from torch.optim import Optimizer
+
+from .base import Callback
+
+
+class SaveCheckpoint(Callback):
+    """
+    Callback that saves checkpoints during chatgpt training.
+
+    Only the actor and critic models (and their optimizers) are saved.
+    A typical layout of the saved checkpoints is:
+    - checkpoint
+        - episode_x
+            - actor.pt
+            - actor-optim-rank-0.pt
+            - actor-optim-rank-1.pt
+            - critic.pt
+            - critic-optim-rank-0.pt
+            - critic-optim-rank-1.pt
+            - ...
+
+    Args:
+        path(str): the base path to save checkpoints; they are written under `path/checkpoint`
+        interval(int): save a checkpoint every `interval` episodes
+        strategy(Strategy): the strategy used for training
+        actor(nn.Module): the actor model
+        critic(nn.Module): the critic model
+        actor_optim(Optimizer): the optimizer of the actor
+        critic_optim(Optimizer): the optimizer of the critic
+
+    """
+
+    def __init__(self,
+                 path: str,
+                 interval: int,
+                 strategy: Strategy,
+                 actor: nn.Module = None,
+                 critic: nn.Module = None,
+                 actor_optim: Optimizer = None,
+                 critic_optim: Optimizer = None) -> None:
+        super().__init__()
+        self.path = os.path.join(path, 'checkpoint')
+        self.interval = interval
+        self.strategy = strategy
+        self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+
+    def on_episode_end(self, episode: int) -> None:
+        if (episode + 1) % self.interval != 0:
+            return
+        base_path = os.path.join(self.path, f'episode_{episode}')
+        if not os.path.exists(base_path):
+            os.makedirs(base_path)
+
+        for model in self.model_dict.keys():
+
+            # save model
+            if self.model_dict[model][0] is None:
+                # saving only the optimizer state is meaningless, so it is skipped
+                continue
+            model_path = os.path.join(base_path, f'{model}.pt')
+            self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
+
+            # save optimizer
+            if self.model_dict[model][1] is None:
+                continue
+            only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
+            rank = 0 if is_rank_0() else dist.get_rank()
+            optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+            self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py
index 35f647491..df64515a1 100644
--- a/applications/ChatGPT/examples/train_dummy.py
+++ b/applications/ChatGPT/examples/train_dummy.py
@@ -4,6 +4,7 @@ from copy import deepcopy
 import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
 from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.callbacks import SaveCheckpoint
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
 from transformers import AutoTokenizer, BloomTokenizerFast
@@ -71,26 +72,38 @@ def main(args):
     (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
         (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
 
+    callbacks = []
+    if args.save_ckpt_path:
+        ckpt_callback = SaveCheckpoint(
+            args.save_ckpt_path,
+            args.save_ckpt_interval,
+            strategy,
+            actor,
+            critic,
+            actor_optim,
+            critic_optim,
+        )
+        callbacks.append(ckpt_callback)
+
     # configure trainer
-    trainer = PPOTrainer(
-        strategy,
-        actor,
-        critic,
-        reward_model,
-        initial_model,
-        actor_optim,
-        critic_optim,
-        max_epochs=args.max_epochs,
-        train_batch_size=args.train_batch_size,
-        experience_batch_size=args.experience_batch_size,
-        tokenizer=preprocess_batch,
-        max_length=128,
-        do_sample=True,
-        temperature=1.0,
-        top_k=50,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-    )
+
+    trainer = PPOTrainer(strategy,
+                         actor,
+                         critic,
+                         reward_model,
+                         initial_model,
+                         actor_optim,
+                         critic_optim,
+                         max_epochs=args.max_epochs,
+                         train_batch_size=args.train_batch_size,
+                         tokenizer=preprocess_batch,
+                         max_length=128,
+                         do_sample=True,
+                         temperature=1.0,
+                         top_k=50,
+                         pad_token_id=tokenizer.pad_token_id,
+                         eos_token_id=tokenizer.eos_token_id,
+                         callbacks=callbacks)
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
     trainer.fit(random_prompts,
@@ -120,5 +133,10 @@ if __name__ == '__main__':
     parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument('--save_ckpt_path',
+                        type=str,
+                        default=None,
+                        help="path to save checkpoints; None means do not save")
+    parser.add_argument('--save_ckpt_interval', type=int, default=1, help="save a checkpoint every this many episodes")
     args = parser.parse_args()
     main(args)
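For illustration, the sketch below shows one way checkpoints produced by SaveCheckpoint could be restored after training. It assumes `strategy.save_model` and `strategy.save_optimizer` write plain `torch.save` state dicts (an assumption not stated in this patch), and that the actor and its optimizer are rebuilt exactly as in `train_dummy.py`; the helper `load_episode_checkpoint` is hypothetical and not part of the patch.

import os

import torch
import torch.distributed as dist


def load_episode_checkpoint(base_path: str, episode: int, actor, actor_optim):
    """Hypothetical helper: restore actor weights and this rank's optimizer state.

    Assumes the files under `base_path/checkpoint/episode_{episode}` are plain
    state dicts written by SaveCheckpoint (an assumption about the strategy's
    save format, not guaranteed by every strategy).
    """
    episode_dir = os.path.join(base_path, 'checkpoint', f'episode_{episode}')

    # actor.pt is written once, since SaveCheckpoint passes only_rank0=True for models
    actor_state = torch.load(os.path.join(episode_dir, 'actor.pt'), map_location='cpu')
    actor.load_state_dict(actor_state)

    # optimizer states are saved per rank under ColossalAIStrategy, by rank 0 only otherwise
    rank = dist.get_rank() if dist.is_initialized() else 0
    optim_path = os.path.join(episode_dir, f'actor-optim-rank-{rank}.pt')
    if os.path.exists(optim_path):
        actor_optim.load_state_dict(torch.load(optim_path, map_location='cpu'))
    return actor, actor_optim

Saving the optimizer per rank for ColossalAIStrategy (only_rank0=False) is presumably because that strategy shards optimizer states across processes, so each rank holds only its own partition.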