mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt]support opt & gpt for rm training (#2876)
parent
c52edcf0eb
commit
2e16f842a9
|
@ -1,6 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
|
|
|
@ -15,12 +15,16 @@ class GPTRM(RewardModel):
|
|||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -29,5 +33,6 @@ class GPTRM(RewardModel):
|
|||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
super().__init__(model, value_head)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTModel
|
||||
from transformers import OPTConfig, OPTModel
|
||||
|
||||
from .reward_model import RewardModel
|
||||
|
||||
|
@ -14,6 +13,7 @@ class OPTRM(RewardModel):
|
|||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
@ -21,6 +21,7 @@ class OPTRM(RewardModel):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
|
@ -29,5 +30,8 @@ class OPTRM(RewardModel):
|
|||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||
|
|
|
@ -3,12 +3,13 @@ import argparse
|
|||
import loralib as lora
|
||||
import torch
|
||||
from chatgpt.dataset import RewardDataset
|
||||
from chatgpt.nn import BLOOMRM
|
||||
from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
|
||||
from chatgpt.trainer import RewardModelTrainer
|
||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from transformers import BloomTokenizerFast
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
@ -27,11 +28,30 @@ def train(args):
|
|||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
with strategy.model_init_context():
|
||||
model = BLOOMRM(pretrained=args.pretrain).cuda()
|
||||
max_len = 1024
|
||||
if args.model == 'bloom':
|
||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
||||
elif args.model == 'opt':
|
||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
||||
elif args.model == 'gpt2':
|
||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
max_len = 512
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
|
@ -58,10 +78,10 @@ def train(args):
|
|||
|
||||
trainer.fit(use_lora=args.lora_rank)
|
||||
|
||||
if args.lora_rank > 0:
|
||||
torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
|
||||
else:
|
||||
torch.save(trainer.model, args.save_path)
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -69,6 +89,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
|
||||
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
|
||||
|
|
|
@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2
|
||||
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
|
||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
|
||||
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
|
||||
|
|
Loading…
Reference in New Issue