"""Dummy PPO training script.

Builds an actor/critic pair for the selected model family, wraps them with the
selected training strategy, and runs ``PPOTrainer.fit`` on randomly generated
prompt token ids.  Afterwards the actor (and optionally its optimizer state)
is saved to disk.
"""
import argparse
from copy import deepcopy

import torch
from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
from chatgpt.models.gpt import GPTActor, GPTCritic
from chatgpt.models.opt import OPTActor, OPTCritic
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import SaveCheckpoint
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam

# Maps the ``--model`` choice to its (actor class, critic class) pair.
_MODEL_CLASSES = {
    'gpt2': (GPTActor, GPTCritic),
    'bloom': (BLOOMActor, BLOOMCritic),
    'opt': (OPTActor, OPTCritic),
}

# Accepted spellings for boolean command-line values (compared lower-cased).
_TRUE_STRINGS = frozenset({'true', 't', 'yes', 'y', '1'})
# The empty string is kept falsy because ``bool('')`` was falsy under the
# previous ``type=bool`` parsing.
_FALSE_STRINGS = frozenset({'false', 'f', 'no', 'n', '0', ''})


def preprocess_batch(samples):
    """Collate a list of prompt tensors into a model input dict.

    Args:
        samples: sequence of tensors accepted by ``torch.stack``.

    Returns:
        A dict with ``'input_ids'`` (the stacked samples) and
        ``'attention_mask'`` (an all-ones ``torch.long`` tensor of the same
        shape as ``input_ids``).
    """
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def _str2bool(value):
    """Parse a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- including ``"False"``.  This converter makes
    the textual value meaningful.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got "{value}"')


def _build_strategy(name):
    """Instantiate the training strategy selected by ``--strategy``.

    Raises:
        ValueError: if ``name`` is not one of the supported strategies.
    """
    if name == 'naive':
        return NaiveStrategy()
    if name == 'ddp':
        return DDPStrategy()
    if name == 'colossalai_gemini':
        return ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    if name == 'colossalai_zero2':
        return ColossalAIStrategy(stage=2, placement_policy='cuda')
    raise ValueError(f'Unsupported strategy "{name}"')


def _build_tokenizer(model_name, pretrain):
    """Load the tokenizer matching ``model_name``.

    For ``gpt2`` and ``bloom`` the pad token is set to the EOS token.  Only
    the ``bloom`` branch uses ``pretrain``; the other two load fixed
    checkpoints.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    if model_name == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{model_name}"')
    return tokenizer


def main(args):
    """Run PPO training on random prompts and save the resulting actor.

    Args:
        args: parsed command-line namespace; see the argument parser at the
            bottom of this file for the available fields.

    Raises:
        ValueError: if ``args.strategy`` or ``args.model`` is unsupported.
    """
    # configure strategy
    strategy = _build_strategy(args.strategy)

    # Resolved after the strategy exists, matching where the original code
    # first queried the current device.
    device = torch.cuda.current_device()

    # configure model
    with strategy.model_init_context():
        if args.model not in _MODEL_CLASSES:
            raise ValueError(f'Unsupported model "{args.model}"')
        actor_cls, critic_cls = _MODEL_CLASSES[args.model]
        actor = actor_cls(pretrained=args.pretrain, lora_rank=args.lora_rank).to(device)
        critic = critic_cls(pretrained=args.pretrain, lora_rank=args.lora_rank).to(device)

        # The initial (reference) model and the reward model start out as
        # copies of the actor and of the critic's backbone/value head.
        # NOTE(review): these copies are created inside the init context
        # together with the trainable models -- confirm that is intended.
        initial_model = deepcopy(actor).to(device)
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(device)

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    # configure tokenizer
    tokenizer = _build_tokenizer(args.model, args.pretrain)

    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    callbacks = []
    if args.save_ckpt_path:
        ckpt_callback = SaveCheckpoint(
            args.save_ckpt_path,
            args.save_ckpt_interval,
            strategy,
            actor,
            critic,
            actor_optim,
            critic_optim,
        )
        callbacks.append(ckpt_callback)

    # configure trainer
    # The prompts below are already token ids, so the "tokenizer" handed to
    # the trainer is just the collate function above.
    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=128,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=callbacks)

    # 1000 random prompts of 64 token ids each, drawn from the vocabulary.
    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=device)
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    # save model checkpoint after fitting
    strategy.save_model(actor, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
                                f'actor_optim_checkpoint_dummy_{torch.cuda.current_device()}.pt',
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
    # Accepts "--need_optim_ckpt", "--need_optim_ckpt True" and
    # "--need_optim_ckpt False"; the latter is now actually false.
    parser.add_argument('--need_optim_ckpt', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--num_episodes', type=int, default=50)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    # NOTE(review): parsed but not forwarded to PPOTrainer in main() --
    # confirm whether it should be.
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--save_ckpt_path',
                        type=str,
                        default=None,
                        help="path to save checkpoint, None means not to save")
    parser.add_argument('--save_ckpt_interval',
                        type=int,
                        default=1,
                        help="the interval of episode to save checkpoint")
    args = parser.parse_args()
    main(args)