diff --git a/applications/ChatGPT/chatgpt/models/llama/__init__.py b/applications/ChatGPT/chatgpt/models/llama/__init__.py
new file mode 100644
index 000000000..9b2a024af
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/__init__.py
@@ -0,0 +1,5 @@
+from .llama_actor import LlamaActor
+from .llama_critic import LlamaCritic
+from .llama_rm import LlamaRM
+
+__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_actor.py b/applications/ChatGPT/chatgpt/models/llama/llama_actor.py
new file mode 100644
index 000000000..2c7adb390
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_actor.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch
+from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+
+from ..base import Actor
+
+
+class LlamaActor(Actor):
+    """
+    Llama Actor model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (LlamaConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[LlamaConfig] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
+
+        if pretrained is not None:
+            model = LlamaForCausalLM.from_pretrained(pretrained)
+        elif config is not None:
+            model = LlamaForCausalLM(config)
+        else:
+            model = LlamaForCausalLM(LlamaConfig())
+
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+
+        super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_critic.py b/applications/ChatGPT/chatgpt/models/llama/llama_critic.py
new file mode 100644
index 000000000..cd565031e
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_critic.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+
+from ..base import Critic
+
+
+class LlamaCritic(Critic):
+    """
+    Llama Critic model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (LlamaConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[LlamaConfig] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
+
+        if pretrained is not None:
+            model = LlamaForCausalLM.from_pretrained(pretrained)
+        elif config is not None:
+            model = LlamaForCausalLM(config)
+        else:
+            model = LlamaForCausalLM(LlamaConfig())
+
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+
+        value_head = nn.Linear(model.config.hidden_size, 1)
+
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_rm.py b/applications/ChatGPT/chatgpt/models/llama/llama_rm.py
new file mode 100644
index 000000000..81fa22d19
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_rm.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers import LlamaConfig, LlamaForCausalLM
+
+from ..base import RewardModel
+
+
+class LlamaRM(RewardModel):
+    """
+    Llama Reward model.
+
+    Args:
+        pretrained (str): Pretrained model name or path.
+        config (LlamaConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): LoRA rank.
+        lora_train_bias (str): LoRA bias training mode.
+    """
+
+    def __init__(self,
+                 pretrained: Optional[str] = None,
+                 config: Optional[LlamaConfig] = None,
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
+
+        if pretrained is not None:
+            model = LlamaForCausalLM.from_pretrained(pretrained)
+        elif config is not None:
+            model = LlamaForCausalLM(config)
+        else:
+            model = LlamaForCausalLM(LlamaConfig())
+
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+
+        value_head = nn.Linear(model.config.hidden_size, 1)
+        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
+
+        super().__init__(model, value_head, lora_rank, lora_train_bias)
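
For reviewers, a minimal usage sketch (not part of this diff) showing how the three new classes are expected to be constructed. The checkpoint path is a placeholder, the lora_rank value is arbitrary, and the surrounding RLHF training loop that would wire the actor, critic, and reward model into a PPO trainer is omitted.

# Usage sketch only; the path below is a placeholder for any LLaMA
# checkpoint converted to Hugging Face format.
from chatgpt.models.llama import LlamaActor, LlamaCritic, LlamaRM

pretrained = 'path/to/llama-hf-checkpoint'  # hypothetical path

# Actor generates responses, critic estimates values, and the reward model
# scores sequences; all three wrap LlamaForCausalLM from transformers.
actor = LlamaActor(pretrained=pretrained, lora_rank=16, checkpoint=True)
critic = LlamaCritic(pretrained=pretrained, lora_rank=16, checkpoint=True)
reward_model = LlamaRM(pretrained=pretrained, lora_rank=16)

# Omitting both `pretrained` and `config` falls back to a randomly
# initialized model with the default LlamaConfig, useful for smoke tests.
debug_actor = LlamaActor()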