import argparse from copy import deepcopy import torch import torch.distributed as dist import torch.nn as nn from chatgpt.nn import GPTActor, GPTCritic, RewardModel from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn from chatgpt.trainer import PPOTrainer from chatgpt.trainer.callbacks import PerformanceEvaluator from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy from torch.optim import Adam from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from colossalai.nn.optimizer import HybridAdam def get_model_numel(model: nn.Module, strategy: Strategy) -> int: numel = sum(p.numel() for p in model.parameters()) if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: numel *= dist.get_world_size() return numel def preprocess_batch(samples) -> dict: input_ids = torch.stack(samples) attention_mask = torch.ones_like(input_ids, dtype=torch.long) return {'input_ids': input_ids, 'attention_mask': attention_mask} def print_rank_0(*args, **kwargs) -> None: if dist.get_rank() == 0: print(*args, **kwargs) def print_model_numel(model_dict: dict) -> None: B = 1024**3 M = 1024**2 K = 1024 outputs = '' for name, numel in model_dict.items(): outputs += f'{name}: ' if numel >= B: outputs += f'{numel / B:.2f} B\n' elif numel >= M: outputs += f'{numel / M:.2f} M\n' elif numel >= K: outputs += f'{numel / K:.2f} K\n' else: outputs += f'{numel}\n' print_rank_0(outputs) def get_gpt_config(model_name: str) -> GPT2Config: model_map = { 's': GPT2Config(), 'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16), 'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20), 'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25), '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16), '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16), '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16), '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16), '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16), '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16), '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16), '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16), '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16), '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16), '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16), '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16), '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16), '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16), '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96), } try: return model_map[model_name] except KeyError: raise ValueError(f'Unknown model "{model_name}"') def main(args): if args.strategy == 'ddp': strategy = DDPStrategy() elif args.strategy == 'colossalai_gemini': strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) elif args.strategy == 'colossalai_gemini_cpu': strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) elif args.strategy == 'colossalai_zero2': strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') elif args.strategy == 'colossalai_zero2_cpu': strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') elif args.strategy == 'colossalai_zero1': strategy = ColossalAIStrategy(stage=1, placement_policy='cuda') elif args.strategy == 'colossalai_zero1_cpu': strategy = ColossalAIStrategy(stage=1, placement_policy='cpu') else: raise ValueError(f'Unsupported strategy "{args.strategy}"') model_config = get_gpt_config(args.model) with strategy.model_init_context(): actor = GPTActor(config=model_config).cuda() critic = GPTCritic(config=model_config).cuda() initial_model = deepcopy(actor).cuda() reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() actor_numel = get_model_numel(actor, strategy) critic_numel = get_model_numel(critic, strategy) initial_model_numel = get_model_numel(initial_model, strategy) reward_model_numel = get_model_numel(reward_model, strategy) print_model_numel({ 'Actor': actor_numel, 'Critic': critic_numel, 'Initial model': initial_model_numel, 'Reward model': reward_model_numel }) performance_evaluator = PerformanceEvaluator(actor_numel, critic_numel, initial_model_numel, reward_model_numel, enable_grad_checkpoint=False, ignore_episodes=1) if args.strategy.startswith('colossalai'): actor_optim = HybridAdam(actor.parameters(), lr=5e-6) critic_optim = HybridAdam(critic.parameters(), lr=5e-6) else: actor_optim = Adam(actor.parameters(), lr=5e-6) critic_optim = Adam(critic.parameters(), lr=5e-6) tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer.pad_token = tokenizer.eos_token trainer = PPOTrainer(strategy, actor, critic, reward_model, initial_model, actor_optim, critic_optim, max_epochs=args.max_epochs, train_batch_size=args.train_batch_size, experience_batch_size=args.experience_batch_size, tokenizer=preprocess_batch, max_length=512, do_sample=True, temperature=1.0, top_k=50, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, prepare_inputs_fn=gpt_prepare_inputs_fn, update_model_kwargs_fn=update_model_kwargs_fn, callbacks=[performance_evaluator]) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) trainer.fit(random_prompts, num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model', default='s') parser.add_argument('--strategy', choices=[ 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' ], default='ddp') parser.add_argument('--num_episodes', type=int, default=3) parser.add_argument('--max_timesteps', type=int, default=8) parser.add_argument('--update_timesteps', type=int, default=8) parser.add_argument('--max_epochs', type=int, default=3) parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument('--experience_batch_size', type=int, default=8) args = parser.parse_args() main(args)