import argparse from copy import deepcopy import torch import torch.distributed as dist import torch.nn as nn from coati.models.base import RewardModel from coati.models.opt import OPTActor, OPTCritic from coati.trainer import PPOTrainer from coati.trainer.callbacks import PerformanceEvaluator from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy from torch.optim import Adam from transformers import AutoTokenizer from transformers.models.opt.configuration_opt import OPTConfig from colossalai.nn.optimizer import HybridAdam def get_model_numel(model: nn.Module, strategy: Strategy) -> int: numel = sum(p.numel() for p in model.parameters()) if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: numel *= dist.get_world_size() return numel def preprocess_batch(samples) -> dict: input_ids = torch.stack(samples) attention_mask = torch.ones_like(input_ids, dtype=torch.long) return {'input_ids': input_ids, 'attention_mask': attention_mask} def print_rank_0(*args, **kwargs) -> None: if dist.get_rank() == 0: print(*args, **kwargs) def print_model_numel(model_dict: dict) -> None: B = 1024**3 M = 1024**2 K = 1024 outputs = '' for name, numel in model_dict.items(): outputs += f'{name}: ' if numel >= B: outputs += f'{numel / B:.2f} B\n' elif numel >= M: outputs += f'{numel / M:.2f} M\n' elif numel >= K: outputs += f'{numel / K:.2f} K\n' else: outputs += f'{numel}\n' print_rank_0(outputs) def get_gpt_config(model_name: str) -> OPTConfig: model_map = { '125m': OPTConfig.from_pretrained('facebook/opt-125m'), '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), '13b': OPTConfig.from_pretrained('facebook/opt-13b'), } try: return model_map[model_name] except KeyError: raise ValueError(f'Unknown model "{model_name}"') def main(args): if args.strategy == 'ddp': strategy = DDPStrategy() elif args.strategy == 'colossalai_gemini': strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) elif args.strategy == 'colossalai_gemini_cpu': strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) elif args.strategy == 'colossalai_zero2': strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') elif args.strategy == 'colossalai_zero2_cpu': strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') elif args.strategy == 'colossalai_zero1': strategy = ColossalAIStrategy(stage=1, placement_policy='cuda') elif args.strategy == 'colossalai_zero1_cpu': strategy = ColossalAIStrategy(stage=1, placement_policy='cpu') else: raise ValueError(f'Unsupported strategy "{args.strategy}"') torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac) model_config = get_gpt_config(args.model) with strategy.model_init_context(): actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda() critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda() initial_model = deepcopy(actor).cuda() reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() actor_numel = get_model_numel(actor, strategy) critic_numel = get_model_numel(critic, strategy) initial_model_numel = get_model_numel(initial_model, strategy) reward_model_numel = get_model_numel(reward_model, strategy) print_model_numel({ 'Actor': actor_numel, 'Critic': critic_numel, 'Initial model': initial_model_numel, 'Reward model': reward_model_numel }) performance_evaluator = PerformanceEvaluator(actor_numel, critic_numel, initial_model_numel, reward_model_numel, enable_grad_checkpoint=False, ignore_episodes=1) if args.strategy.startswith('colossalai'): actor_optim = HybridAdam(actor.parameters(), lr=5e-6) critic_optim = HybridAdam(critic.parameters(), lr=5e-6) else: actor_optim = Adam(actor.parameters(), lr=5e-6) critic_optim = Adam(critic.parameters(), lr=5e-6) tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') tokenizer.pad_token = tokenizer.eos_token (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( (actor, actor_optim), (critic, critic_optim), reward_model, initial_model) trainer = PPOTrainer(strategy, actor, critic, reward_model, initial_model, actor_optim, critic_optim, max_epochs=args.max_epochs, train_batch_size=args.train_batch_size, experience_batch_size=args.experience_batch_size, tokenizer=preprocess_batch, max_length=512, do_sample=True, temperature=1.0, top_k=50, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, callbacks=[performance_evaluator]) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) trainer.fit(random_prompts, num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model', default='125m') parser.add_argument('--strategy', choices=[ 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' ], default='ddp') parser.add_argument('--num_episodes', type=int, default=3) parser.add_argument('--max_timesteps', type=int, default=8) parser.add_argument('--update_timesteps', type=int, default=8) parser.add_argument('--max_epochs', type=int, default=3) parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument('--lora_rank', type=int, default=4) parser.add_argument('--cuda_mem_frac', type=float, default=1.0) args = parser.parse_args() main(args)