mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] Add saving ckpt callback for PPO (#2880)
* add checkpoint callback for chatgpt
* add save ckpt callbacks for ppo

Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
parent e588703454
commit 287d60499e
@@ -1,4 +1,5 @@
 from .base import Callback
 from .performance_evaluator import PerformanceEvaluator
+from .save_checkpoint import SaveCheckpoint
 
-__all__ = ['Callback', 'PerformanceEvaluator']
+__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
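With the `__init__` change above, the new callback is re-exported alongside the existing ones. A two-line check, assuming the `chatgpt` package is on the path:

    from chatgpt.trainer.callbacks import Callback, PerformanceEvaluator, SaveCheckpoint

    # SaveCheckpoint subclasses Callback (see the new file below), so it can be
    # passed in the same callbacks list as PerformanceEvaluator.
    print(issubclass(SaveCheckpoint, Callback))    # True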
@@ -0,0 +1,75 @@
+import os
+
+import torch.distributed as dist
+from chatgpt.trainer.strategies import ColossalAIStrategy, Strategy
+from chatgpt.trainer.utils import is_rank_0
+from torch import nn
+from torch.optim import Optimizer
+
+from .base import Callback
+
+
+class SaveCheckpoint(Callback):
+    """
+    Callback for saving checkpoints during ChatGPT training.
+
+    Only the actor and critic models (and their optimizers) are saved.
+    A typical layout of the saved checkpoints is:
+
+    - checkpoint
+        - episode_x
+            - actor.pt
+            - actor-optim-rank-0.pt
+            - actor-optim-rank-1.pt
+            - critic.pt
+            - critic-optim-rank-0.pt
+            - critic-optim-rank-1.pt
+        - ...
+
+    Args:
+        path(str): base path for checkpoints; they are written under `path/checkpoint`
+        interval(int): a checkpoint is saved every `interval` episodes
+        strategy(Strategy): the strategy used for training
+        actor(nn.Module): the actor model
+        critic(nn.Module): the critic model
+        actor_optim(Optimizer): the optimizer of the actor
+        critic_optim(Optimizer): the optimizer of the critic
+    """
+
+    def __init__(self,
+                 path: str,
+                 interval: int,
+                 strategy: Strategy,
+                 actor: nn.Module = None,
+                 critic: nn.Module = None,
+                 actor_optim: Optimizer = None,
+                 critic_optim: Optimizer = None) -> None:
+        super().__init__()
+        self.path = os.path.join(path, 'checkpoint')
+        self.interval = interval
+        self.strategy = strategy
+        self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+
+    def on_episode_end(self, episode: int) -> None:
+        if (episode + 1) % self.interval != 0:
+            return
+        base_path = os.path.join(self.path, f'episode_{episode}')
+        if not os.path.exists(base_path):
+            os.makedirs(base_path)
+
+        for model in self.model_dict.keys():
+
+            # save model
+            if self.model_dict[model][0] is None:
+                # saving optimizer states without the model is meaningless, so skip this entry
+                continue
+            model_path = os.path.join(base_path, f'{model}.pt')
+            self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
+
+            # save optimizer
+            if self.model_dict[model][1] is None:
+                continue
+            # ColossalAIStrategy shards optimizer states, so every rank writes its own file
+            only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
+            rank = 0 if is_rank_0() else dist.get_rank()
+            optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+            self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
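A quick way to see the resulting layout is to drive the callback by hand in a single process. The sketch below is illustrative, not part of the commit: it assumes the `chatgpt` package is importable, that `NaiveStrategy` can be constructed outside a distributed run, and that `is_rank_0()` returns True when `torch.distributed` is not initialized.

    import os
    import tempfile

    from torch import nn
    from torch.optim import Adam

    from chatgpt.trainer.callbacks import SaveCheckpoint
    from chatgpt.trainer.strategies import NaiveStrategy

    # Tiny stand-in modules; real runs would pass the actual actor/critic.
    actor, critic = nn.Linear(8, 8), nn.Linear(8, 1)

    with tempfile.TemporaryDirectory() as tmp:
        cb = SaveCheckpoint(path=tmp,
                            interval=2,
                            strategy=NaiveStrategy(),
                            actor=actor,
                            critic=critic,
                            actor_optim=Adam(actor.parameters(), lr=1e-3),
                            critic_optim=Adam(critic.parameters(), lr=1e-3))
        cb.on_episode_end(0)    # (0 + 1) % 2 != 0 -> nothing is written
        cb.on_episode_end(1)    # (1 + 1) % 2 == 0 -> writes checkpoint/episode_1
        print(sorted(os.listdir(os.path.join(tmp, 'checkpoint', 'episode_1'))))
        # Expected: ['actor-optim-rank-0.pt', 'actor.pt', 'critic-optim-rank-0.pt', 'critic.pt']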
@@ -4,6 +4,7 @@ from copy import deepcopy
 import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
 from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.callbacks import SaveCheckpoint
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
 from transformers import AutoTokenizer, BloomTokenizerFast
@@ -71,26 +72,38 @@ def main(args):
     (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
         (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
 
+    callbacks = []
+    if args.save_ckpt_path:
+        ckpt_callback = SaveCheckpoint(
+            args.save_ckpt_path,
+            args.save_ckpt_interval,
+            strategy,
+            actor,
+            critic,
+            actor_optim,
+            critic_optim,
+        )
+        callbacks.append(ckpt_callback)
+
     # configure trainer
-    trainer = PPOTrainer(
-        strategy,
-        actor,
-        critic,
-        reward_model,
-        initial_model,
-        actor_optim,
-        critic_optim,
-        max_epochs=args.max_epochs,
-        train_batch_size=args.train_batch_size,
-        experience_batch_size=args.experience_batch_size,
-        tokenizer=preprocess_batch,
-        max_length=128,
-        do_sample=True,
-        temperature=1.0,
-        top_k=50,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-    )
+    trainer = PPOTrainer(strategy,
+                         actor,
+                         critic,
+                         reward_model,
+                         initial_model,
+                         actor_optim,
+                         critic_optim,
+                         max_epochs=args.max_epochs,
+                         train_batch_size=args.train_batch_size,
+                         tokenizer=preprocess_batch,
+                         max_length=128,
+                         do_sample=True,
+                         temperature=1.0,
+                         top_k=50,
+                         pad_token_id=tokenizer.pad_token_id,
+                         eos_token_id=tokenizer.eos_token_id,
+                         callbacks=callbacks)
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
     trainer.fit(random_prompts,
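For orientation, the trainer itself (not shown in this diff) is what fires these hooks: once an episode finishes, each callback in `callbacks` has its `on_episode_end` invoked, which is the entry point `SaveCheckpoint` relies on. A schematic of that dispatch, with the training body reduced to a comment; the real loop lives in `PPOTrainer`:

    # Schematic only: not the actual PPOTrainer implementation.
    class PrintCallback:
        def on_episode_end(self, episode: int) -> None:
            print(f'episode {episode} finished')

    def fit(num_episodes, callbacks):
        for episode in range(num_episodes):
            # ... collect experience and run PPO updates ...
            for callback in callbacks:
                callback.on_episode_end(episode)    # SaveCheckpoint saves here every `interval` episodes

    fit(num_episodes=3, callbacks=[PrintCallback()])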
@@ -120,5 +133,10 @@ if __name__ == '__main__':
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument('--save_ckpt_path',
+                        type=str,
+                        default=None,
+                        help="path to save checkpoints; None means do not save")
+    parser.add_argument('--save_ckpt_interval', type=int, default=1, help="number of episodes between checkpoints")
     args = parser.parse_args()
     main(args)
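Because `--save_ckpt_path` defaults to None, checkpointing stays opt-in: the guard in `main()` only builds the callback when a path is supplied. A standalone check of just the two new flags (the parser below reproduces only this commit's additions):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--save_ckpt_path', type=str, default=None,
                        help="path to save checkpoints; None means do not save")
    parser.add_argument('--save_ckpt_interval', type=int, default=1,
                        help="number of episodes between checkpoints")
    args = parser.parse_args(['--save_ckpt_path', './ckpt', '--save_ckpt_interval', '2'])

    # Mirrors the guard in main(): no path, no callback.
    if args.save_ckpt_path:
        print(f'checkpoints go to {args.save_ckpt_path} every {args.save_ckpt_interval} episode(s)')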