Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

179 lines
7.5 KiB

import argparse
from copy import deepcopy
import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.models.base import RewardModel
from chatgpt.models.opt import OPTActor, OPTCritic
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam
from transformers import AutoTokenizer
from transformers.models.opt.configuration_opt import OPTConfig
from colossalai.nn.optimizer import HybridAdam
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    """Return the number of parameters in ``model``.

    For a ColossalAIStrategy at ZeRO stage 3 with ``shard_init`` enabled, the
    local count is multiplied by the distributed world size (presumably because
    each rank only materializes a shard of the parameters in that mode).
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    sharded = isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init
    return total * dist.get_world_size() if sharded else total
def preprocess_batch(samples) -> dict:
    """Collate already-tokenized prompt tensors into a model-input dict.

    ``samples`` is stacked along a new leading batch dimension (assumes the
    tensors share one shape -- required by ``torch.stack``). The attention mask
    is all ones (dtype long), i.e. no position is treated as padding.
    """
    batch = torch.stack(samples)
    return dict(input_ids=batch, attention_mask=torch.ones_like(batch, dtype=torch.long))
def print_rank_0(*args, **kwargs) -> None:
    """Forward ``args``/``kwargs`` to ``print`` on global rank 0 only.

    If ``torch.distributed`` is unavailable or no process group has been
    initialized yet, there is no rank to check, so the message is printed
    unconditionally instead of raising from ``dist.get_rank()``.
    """
    # Short-circuit keeps get_rank() from being called without a process group.
    if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
        print(*args, **kwargs)
def print_model_numel(model_dict: dict) -> None:
    """Print a human-readable parameter count for each named model (rank 0 only).

    Args:
        model_dict: maps a display name to a parameter count (int).

    Each count is scaled by the largest of 10**9 ('B'), 10**6 ('M') or 10**3
    ('K') it reaches; smaller counts are printed verbatim. Decimal (not 1024-based)
    thresholds are used so the suffixes mean billion/million/thousand parameters
    and agree with size keys such as '125m' or '1.3b' used by get_gpt_config.
    """
    units = ((10**9, 'B'), (10**6, 'M'), (10**3, 'K'))
    lines = []
    for name, numel in model_dict.items():
        for scale, suffix in units:
            if numel >= scale:
                lines.append(f'{name}: {numel / scale:.2f} {suffix}\n')
                break
        else:
            # Below one thousand: print the exact count.
            lines.append(f'{name}: {numel}\n')
    print_rank_0(''.join(lines))
def get_gpt_config(model_name: str) -> OPTConfig:
    """Return the OPT model configuration for a size key such as '125m' or '1.3b'.

    Sizes with an official facebook/opt checkpoint are loaded via
    ``OPTConfig.from_pretrained``; the remaining sizes are hand-built configs.

    Args:
        model_name: one of '125m', '350m', '700m', '1.3b', '2.7b', '3.5b',
            '5.5b', '6.7b', '10b', '13b'.

    Raises:
        ValueError: if ``model_name`` is not a known size key.
    """
    # Values are zero-argument factories so only the requested config is built.
    # Building them all eagerly would run every ``from_pretrained`` lookup (which
    # may hit the network) on each call, even for the hand-built sizes.
    model_map = {
        '125m': lambda: OPTConfig.from_pretrained('facebook/opt-125m'),
        '350m': lambda: OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
        '700m': lambda: OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
        '1.3b': lambda: OPTConfig.from_pretrained('facebook/opt-1.3b'),
        '2.7b': lambda: OPTConfig.from_pretrained('facebook/opt-2.7b'),
        '3.5b': lambda: OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
        '5.5b': lambda: OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
        '6.7b': lambda: OPTConfig.from_pretrained('facebook/opt-6.7b'),
        '10b': lambda: OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
        '13b': lambda: OPTConfig.from_pretrained('facebook/opt-13b'),
    }
    try:
        factory = model_map[model_name]
    except KeyError:
        # The KeyError adds nothing beyond the message below; suppress the chain.
        raise ValueError(f'Unknown model "{model_name}"') from None
    return factory()
def main(args):
    """Run the OPT PPO benchmark on random prompts with the selected strategy.

    Builds the training strategy, the four PPO models (actor, critic, a copy of
    the actor used as the initial model, and a reward model made from copies of
    the critic's backbone and value head), their optimizers and a PPOTrainer.
    It then trains on random token prompts and prints peak CUDA memory on rank 0.

    Args:
        args: parsed command-line namespace produced by the ``__main__`` block.

    Raises:
        ValueError: if ``args.strategy`` is not a supported strategy name.
    """
    # Factories, so only the selected strategy is ever constructed.
    strategy_factories = {
        'ddp': DDPStrategy,
        'colossalai_gemini': lambda: ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5),
        'colossalai_gemini_cpu': lambda: ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5),
        'colossalai_zero2': lambda: ColossalAIStrategy(stage=2, placement_policy='cuda'),
        'colossalai_zero2_cpu': lambda: ColossalAIStrategy(stage=2, placement_policy='cpu'),
        'colossalai_zero1': lambda: ColossalAIStrategy(stage=1, placement_policy='cuda'),
        'colossalai_zero1_cpu': lambda: ColossalAIStrategy(stage=1, placement_policy='cpu'),
    }
    if args.strategy not in strategy_factories:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')
    strategy = strategy_factories[args.strategy]()

    torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)

    model_config = get_gpt_config(args.model)
    # All four models are created inside the strategy's init context.
    with strategy.model_init_context():
        actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
        critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()
        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    named_models = {
        'Actor': actor,
        'Critic': critic,
        'Initial model': initial_model,
        'Reward model': reward_model,
    }
    numels = {label: get_model_numel(model, strategy) for label, model in named_models.items()}
    print_model_numel(numels)

    performance_evaluator = PerformanceEvaluator(numels['Actor'],
                                                 numels['Critic'],
                                                 numels['Initial model'],
                                                 numels['Reward model'],
                                                 enable_grad_checkpoint=False,
                                                 ignore_episodes=1)

    # ColossalAI strategies use HybridAdam; plain DDP uses torch's Adam.
    optim_cls = HybridAdam if args.strategy.startswith('colossalai') else Adam
    actor_optim = optim_cls(actor.parameters(), lr=5e-6)
    critic_optim = optim_cls(critic.parameters(), lr=5e-6)

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    tokenizer.pad_token = tokenizer.eos_token

    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    generate_kwargs = dict(max_length=512, do_sample=True, temperature=1.0, top_k=50)
    # Prompts are already token ids, so the "tokenizer" only collates them.
    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=preprocess_batch,
                         **generate_kwargs,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    # 1000 random prompts of 400 token ids each, created on the current CUDA device.
    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print_rank_0(f'Peak CUDA mem: {peak_gb:.2f} GB')
if __name__ == '__main__':
    # Command-line entry point for the random-prompt OPT PPO benchmark.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='125m')
    strategy_names = [
        'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
        'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
    ]
    parser.add_argument('--strategy', choices=strategy_names, default='ddp')
    # Integer options, registered in this order with these defaults.
    int_options = (
        ('--num_episodes', 3),
        ('--max_timesteps', 8),
        ('--update_timesteps', 8),
        ('--max_epochs', 3),
        ('--train_batch_size', 8),
        ('--experience_batch_size', 8),
        ('--lora_rank', 4),
    )
    for flag, flag_default in int_options:
        parser.add_argument(flag, type=int, default=flag_default)
    parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
    args = parser.parse_args()
    main(args)