From 1e1b9d2feabc6252818352fdd71772dd46fbe41d Mon Sep 17 00:00:00 2001
From: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Date: Wed, 22 Mar 2023 15:44:31 +0800
Subject: [PATCH] [chatgpt]support llama (#3070)

---
 .../ChatGPT/chatgpt/models/llama/__init__.py  |  5 +++
 .../chatgpt/models/llama/llama_actor.py       | 38 +++++++++++++++++
 .../chatgpt/models/llama/llama_critic.py      | 42 +++++++++++++++++++
 .../ChatGPT/chatgpt/models/llama/llama_rm.py  | 41 ++++++++++++++++++
 4 files changed, 126 insertions(+)
 create mode 100644 applications/ChatGPT/chatgpt/models/llama/__init__.py
 create mode 100644 applications/ChatGPT/chatgpt/models/llama/llama_actor.py
 create mode 100644 applications/ChatGPT/chatgpt/models/llama/llama_critic.py
 create mode 100644 applications/ChatGPT/chatgpt/models/llama/llama_rm.py

diff --git a/applications/ChatGPT/chatgpt/models/llama/__init__.py b/applications/ChatGPT/chatgpt/models/llama/__init__.py
new file mode 100644
index 000000000..9b2a024af
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/__init__.py
@@ -0,0 +1,5 @@
+from .llama_actor import LlamaActor
+from .llama_critic import LlamaCritic
+from .llama_rm import LlamaRM
+
+__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_actor.py b/applications/ChatGPT/chatgpt/models/llama/llama_actor.py
new file mode 100644
index 000000000..2c7adb390
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_actor.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch
+from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+
+from ..base import Actor
+
+
+class LlamaActor(Actor):
+    """
+    Llama Actor model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (LlamaConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[LlamaConfig] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
+
+        if pretrained is not None:
+            model = LlamaForCausalLM.from_pretrained(pretrained)
+        elif config is not None:
+            model = LlamaForCausalLM(config)
+        else:
+            model = LlamaForCausalLM(LlamaConfig())
+
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+
+        super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_critic.py b/applications/ChatGPT/chatgpt/models/llama/llama_critic.py
new file mode 100644
index 000000000..cd565031e
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_critic.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+
+from ..base import Critic
+
+
+class LlamaCritic(Critic):
+    """
+    Llama Critic model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (LlamaConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+ """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + + if pretrained is not None: + model = LlamaForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = LlamaForCausalLM(config) + else: + model = LlamaForCausalLM(LlamaConfig()) + + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.hidden_size, 1) + + super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_rm.py b/applications/ChatGPT/chatgpt/models/llama/llama_rm.py new file mode 100644 index 000000000..81fa22d19 --- /dev/null +++ b/applications/ChatGPT/chatgpt/models/llama/llama_rm.py @@ -0,0 +1,41 @@ +from typing import Optional + +import torch.nn as nn +from transformers import LlamaConfig, LlamaForCausalLM + +from ..base import RewardModel + + +class LlamaRM(RewardModel): + """ + Llama Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (LlamaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + + if pretrained is not None: + model = LlamaForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = LlamaForCausalLM(config) + else: + model = LlamaForCausalLM(LlamaConfig()) + + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.hidden_size, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) + + super().__init__(model, lora_rank, lora_train_bias)