mirror of https://github.com/hpcaitech/ColossalAI
Add RoBERTa for RLHF Stage 2 & 3 (test)
RoBERTa for RLHF Stage 2 & 3 (still in testing)pull/3206/head
parent
1e1b9d2fea
commit
06741d894d
|
@ -0,0 +1,5 @@
|
||||||
|
from .roberta_actor import RoBERTaActor
|
||||||
|
from .roberta_critic import RoBERTaCritic
|
||||||
|
from .roberta_rm import RoBERTaRM
|
||||||
|
|
||||||
|
__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']
|
|
@ -0,0 +1,35 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||||
|
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM
|
||||||
|
|
||||||
|
from ..base import Actor
|
||||||
|
|
||||||
|
class RoBERTaActor(Actor):
|
||||||
|
"""
|
||||||
|
RoBERTa Actor model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (str): Pretrained model name or path.
|
||||||
|
config (RoBERTaConfig): Model config.
|
||||||
|
checkpoint (bool): Enable gradient checkpointing.
|
||||||
|
lora_rank (int): Rank of the low-rank approximation.
|
||||||
|
lora_train_bias (str): LoRA bias training mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
config: Optional[RobertaConfig] = None,
|
||||||
|
checkpoint: bool = False,
|
||||||
|
lora_rank: int = 0,
|
||||||
|
lora_train_bias: str = 'none') -> None:
|
||||||
|
if pretrained is not None:
|
||||||
|
model = RobertaForCausalLM.from_pretrained(pretrained)
|
||||||
|
elif config is not None:
|
||||||
|
model = RobertaForCausalLM(config)
|
||||||
|
else:
|
||||||
|
model = RobertaForCausalLM(RobertaConfig())
|
||||||
|
if checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,38 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||||
|
from transformers.models.roberta.modeling_roberta import RobertaModel
|
||||||
|
|
||||||
|
from ..base import Critic
|
||||||
|
|
||||||
|
|
||||||
|
class RoBERTaCritic(Critic):
|
||||||
|
"""
|
||||||
|
RoBERTa Critic model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (str): Pretrained model name or path.
|
||||||
|
config (RoBERTa Config): Model config.
|
||||||
|
checkpoint (bool): Enable gradient checkpointing.
|
||||||
|
lora_rank (int): Rank of the low-rank approximation.
|
||||||
|
lora_train_bias (str): LoRA bias training mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
config: Optional[RobertaConfig] = None,
|
||||||
|
checkpoint: bool = False,
|
||||||
|
lora_rank: int = 0,
|
||||||
|
lora_train_bias: str = 'none',
|
||||||
|
**kwargs) -> None:
|
||||||
|
if pretrained is not None:
|
||||||
|
model = RobertaModel.from_pretrained(pretrained)
|
||||||
|
elif config is not None:
|
||||||
|
model = RobertaModel(config)
|
||||||
|
else:
|
||||||
|
model = RobertaModel(RobertaConfig())
|
||||||
|
if checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||||
|
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
|
@ -0,0 +1,39 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import RobertaConfig, RobertaModel
|
||||||
|
|
||||||
|
|
||||||
|
from ..base import RewardModel
|
||||||
|
|
||||||
|
|
||||||
|
class RoBERTaRM(RewardModel):
|
||||||
|
"""
|
||||||
|
RoBERTa Reward model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (str): Pretrained model name or path.
|
||||||
|
config (RoBERTaConfig): Model config.
|
||||||
|
checkpoint (bool): Enable gradient checkpointing.
|
||||||
|
lora_rank (int): Rank of the low-rank approximation.
|
||||||
|
lora_train_bias (str): LoRA bias training mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
config: Optional[RobertaConfig] = None,
|
||||||
|
checkpoint: bool = False,
|
||||||
|
lora_rank: int = 0,
|
||||||
|
lora_train_bias: str = 'none') -> None:
|
||||||
|
if pretrained is not None:
|
||||||
|
model = RobertaModel.from_pretrained(pretrained)
|
||||||
|
elif config is not None:
|
||||||
|
model = RobertaModel(config)
|
||||||
|
else:
|
||||||
|
model = RobertaModel(RobertaConfig())
|
||||||
|
if checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||||
|
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
|
||||||
|
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -6,11 +6,12 @@ from chatgpt.models.base import RewardModel
|
||||||
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
|
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
|
||||||
from chatgpt.models.gpt import GPTActor, GPTCritic
|
from chatgpt.models.gpt import GPTActor, GPTCritic
|
||||||
from chatgpt.models.opt import OPTActor, OPTCritic
|
from chatgpt.models.opt import OPTActor, OPTCritic
|
||||||
|
from chatgpt.models.roberta import RoBERTaActor, RoBERTaCritic
|
||||||
from chatgpt.trainer import PPOTrainer
|
from chatgpt.trainer import PPOTrainer
|
||||||
from chatgpt.trainer.callbacks import SaveCheckpoint
|
from chatgpt.trainer.callbacks import SaveCheckpoint
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
@ -46,6 +47,9 @@ def main(args):
|
||||||
elif args.model == 'opt':
|
elif args.model == 'opt':
|
||||||
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
|
elif args.model == 'roberta':
|
||||||
|
actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
|
critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
|
@ -69,6 +73,9 @@ def main(args):
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
elif args.model == 'opt':
|
elif args.model == 'opt':
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
|
elif args.model == 'roberta':
|
||||||
|
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
|
@ -128,7 +135,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--strategy',
|
parser.add_argument('--strategy',
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||||
default='naive')
|
default='naive')
|
||||||
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
parser.add_argument('--pretrain', type=str, default=None)
|
||||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
|
parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
|
||||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||||
|
|
|
@ -8,12 +8,13 @@ from chatgpt.models.base import RewardModel
|
||||||
from chatgpt.models.bloom import BLOOMRM
|
from chatgpt.models.bloom import BLOOMRM
|
||||||
from chatgpt.models.gpt import GPTRM
|
from chatgpt.models.gpt import GPTRM
|
||||||
from chatgpt.models.opt import OPTRM
|
from chatgpt.models.opt import OPTRM
|
||||||
|
from chatgpt.models.roberta import RoBERTaRM
|
||||||
from chatgpt.trainer import RewardModelTrainer
|
from chatgpt.trainer import RewardModelTrainer
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from random import randint
|
from random import randint
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
@ -39,6 +40,8 @@ def train(args):
|
||||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
elif args.model == 'gpt2':
|
elif args.model == 'gpt2':
|
||||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
|
elif args.model == 'roberta':
|
||||||
|
model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
|
@ -54,6 +57,9 @@ def train(args):
|
||||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||||
elif args.model == 'opt':
|
elif args.model == 'opt':
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
|
elif args.model == 'roberta':
|
||||||
|
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
max_len = args.max_len
|
max_len = args.max_len
|
||||||
|
@ -119,7 +125,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--strategy',
|
parser.add_argument('--strategy',
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||||
default='naive')
|
default='naive')
|
||||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
|
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'roberta'], default='bloom')
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
parser.add_argument('--pretrain', type=str, default=None)
|
||||||
parser.add_argument('--model_path', type=str, default=None)
|
parser.add_argument('--model_path', type=str, default=None)
|
||||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||||
|
|
Loading…
Reference in New Issue