[chatgpt] Support saving ckpt in examples (#2846)

* [chatgpt] fix train_rm bug with LoRA

* [chatgpt] support ColossalAI strategy to train RM

* fix pre-commit

* fix pre-commit 2

* [chatgpt] fix RM eval typo

* fix rm eval

* fix pre-commit

* add support of saving ckpt in examples

* fix single-gpu save
pull/2863/head
BlueRum 2023-02-22 10:00:26 +08:00 committed by GitHub
parent 597914317b
commit 34ca324b0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 0 deletions

View File

@ -97,6 +97,13 @@ def main(args):
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
# save model checkpoint after fitting on only rank0
strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()

View File

@ -2,6 +2,7 @@ import argparse
from copy import deepcopy
import pandas as pd
import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
@ -95,6 +96,12 @@ def main(args):
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
# save model checkpoint after fitting on only rank0
strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':