mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] Support saving ckpt in examples (#2846)
* [chatgpt]fix train_rm bug with lora
* [chatgpt]support colossalai strategy to train rm
* fix pre-commit
* fix pre-commit 2
* [chatgpt]fix rm eval typo
* fix rm eval
* fix pre commit
* add support of saving ckpt in examples
* fix single-gpu save

pull/2863/head
parent
597914317b
commit
34ca324b0d
|
@@ -97,6 +97,13 @@ def main(args):
|
|||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
strategy.save_optimizer(actor_optim,
|
||||
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
|
@@ -2,6 +2,7 @@ import argparse
|
|||
from copy import deepcopy
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
|
||||
from chatgpt.trainer import PPOTrainer
|
||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
|
@@ -95,6 +96,12 @@ def main(args):
|
|||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
strategy.save_optimizer(actor_optim,
|
||||
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue