
Add RoBERTa for RLHF Stage 2 & 3 (test)

RoBERTa for RLHF Stage 2 & 3 (still in testing)
pull/3206/head
Camille Zhong, 2 years ago · commit 06741d894d
1. applications/ChatGPT/chatgpt/models/roberta/__init__.py (+5)
2. applications/ChatGPT/chatgpt/models/roberta/roberta_actor.py (+35)
3. applications/ChatGPT/chatgpt/models/roberta/roberta_critic.py (+38)
4. applications/ChatGPT/chatgpt/models/roberta/roberta_rm.py (+39)
5. applications/ChatGPT/examples/train_dummy.py (11 lines changed)
6. applications/ChatGPT/examples/train_reward_model.py (10 lines changed)

applications/ChatGPT/chatgpt/models/roberta/__init__.py (new file, +5)

@@ -0,0 +1,5 @@
from .roberta_actor import RoBERTaActor
from .roberta_critic import RoBERTaCritic
from .roberta_rm import RoBERTaRM

__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']

applications/ChatGPT/chatgpt/models/roberta/roberta_actor.py (new file, +35)

@@ -0,0 +1,35 @@
from typing import Optional

from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM

from ..base import Actor


class RoBERTaActor(Actor):
    """
    RoBERTa Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = RobertaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaForCausalLM(config)
        else:
            model = RobertaForCausalLM(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
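
A quick usage sketch (not part of the commit; checkpoint name and LoRA rank are illustrative assumptions): the actor can be built straight from the stock Hugging Face checkpoint. Note that transformers expects config.is_decoder=True when RobertaForCausalLM is used as a standalone generator and will warn otherwise.

    # Hypothetical example: 'roberta-base' and lora_rank=8 are assumed, not from the commit.
    from chatgpt.models.roberta import RoBERTaActor

    actor = RoBERTaActor(pretrained='roberta-base', checkpoint=True, lora_rank=8)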

applications/ChatGPT/chatgpt/models/roberta/roberta_critic.py (new file, +38)

@@ -0,0 +1,38 @@
from typing import Optional

import torch.nn as nn
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaModel

from ..base import Critic


class RoBERTaCritic(Critic):
    """
    RoBERTa Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
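
The critic wraps the bare RobertaModel and attaches a scalar value head (hidden_size -> 1); extra keyword arguments are forwarded to the Critic base class. A matching hypothetical instantiation, under the same assumptions as the actor sketch above:

    from chatgpt.models.roberta import RoBERTaCritic

    # 'roberta-base' is assumed; the value head keeps nn.Linear's default initialization here.
    critic = RoBERTaCritic(pretrained='roberta-base', lora_rank=8)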

applications/ChatGPT/chatgpt/models/roberta/roberta_rm.py (new file, +39)

@@ -0,0 +1,39 @@
from typing import Optional

import torch.nn as nn
from transformers import RobertaConfig, RobertaModel

from ..base import RewardModel


class RoBERTaRM(RewardModel):
    """
    RoBERTa Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
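
Unlike the critic, the reward model explicitly re-initializes its value-head weights from a normal distribution with std = 1/(hidden_size + 1). For roberta-base (hidden_size = 768) that gives std = 1/769, roughly 0.0013, so initial reward estimates start near zero rather than at nn.Linear's default scale.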

applications/ChatGPT/examples/train_dummy.py (11 lines changed)

@@ -6,11 +6,12 @@ from chatgpt.models.base import RewardModel
 from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
 from chatgpt.models.gpt import GPTActor, GPTCritic
 from chatgpt.models.opt import OPTActor, OPTCritic
+from chatgpt.models.roberta import RoBERTaActor, RoBERTaCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import SaveCheckpoint
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -46,6 +47,9 @@ def main(args):
     elif args.model == 'opt':
         actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'roberta':
+        actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+        critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -69,6 +73,9 @@ def main(args):
         tokenizer.pad_token = tokenizer.eos_token
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+        tokenizer.pad_token = tokenizer.eos_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -128,7 +135,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
-    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
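
With this change, the dummy PPO example should be launchable along these lines (checkpoint name assumed, not from the commit):

    python train_dummy.py --model roberta --pretrain roberta-base --strategy naive

Omitting --pretrain falls back to a freshly initialized RobertaConfig model, per the actor/critic constructors above.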

applications/ChatGPT/examples/train_reward_model.py (10 lines changed)

@@ -8,12 +8,13 @@ from chatgpt.models.base import RewardModel
 from chatgpt.models.bloom import BLOOMRM
 from chatgpt.models.gpt import GPTRM
 from chatgpt.models.opt import OPTRM
+from chatgpt.models.roberta import RoBERTaRM
 from chatgpt.trainer import RewardModelTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from datasets import load_dataset
 from random import randint
 from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -39,6 +40,8 @@ def train(args):
         model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
     elif args.model == 'gpt2':
         model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'roberta':
+        model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -54,6 +57,9 @@ def train(args):
         tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+        tokenizer.pad_token = tokenizer.eos_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
     max_len = args.max_len
@@ -119,7 +125,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'roberta'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--model_path', type=str, default=None)
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
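
Similarly, reward-model training with the new backbone would presumably run as (checkpoint name assumed, other flags at their defaults):

    python train_reward_model.py --model roberta --pretrain roberta-base --strategy naive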
