ColossalAI/applications/Chat/coati/models/gpt/gpt_critic.py

from typing import Optional

import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from ..base import Critic


class GPTCritic(Critic):
    """
    GPT Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(
        self,
        pretrained: Optional[str] = None,
        config: Optional[GPT2Config] = None,
        lora_rank: int = 0,
        lora_train_bias: str = "none",
        **kwargs,
    ) -> None:
        # Precedence: pretrained checkpoint, then explicit config, then the default GPT2Config.
        if pretrained is not None:
            model = GPT2Model.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2Model(config)
        else:
            model = GPT2Model(GPT2Config())
        # Value head: projects each hidden state of size n_embd to a single scalar.
        value_head = nn.Linear(model.config.n_embd, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
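The constructor above pairs a `GPT2Model` backbone with an `nn.Linear(n_embd, 1)` value head. The following is a minimal standalone sketch of what that pairing computes, not the `Critic` base class itself. Its forward pass is not shown in this file, so the per-token value layout here is an assumption. The tiny config values (`n_embd=32`, `n_layer=1`, `n_head=2`, `vocab_size=100`, `n_positions=16`) and the dummy inputs are made up so the example runs without downloading weights.

```python
import torch
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

# Hypothetical tiny config: randomly initialised, no checkpoint download needed.
config = GPT2Config(n_embd=32, n_layer=1, n_head=2, vocab_size=100, n_positions=16)
model = GPT2Model(config)
model.eval()

# Same shape of head as in GPTCritic: hidden size -> one scalar.
value_head = nn.Linear(model.config.n_embd, 1)

# Dummy batch: 2 sequences of 5 token ids each, with no padding.
input_ids = torch.randint(0, config.vocab_size, (2, 5))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Backbone output: one hidden vector per token, shape (batch, seq_len, n_embd).
    hidden = model(input_ids, attention_mask=attention_mask).last_hidden_state
    # Assumed usage: apply the head per token and drop the trailing size-1 dimension.
    values = value_head(hidden).squeeze(-1)

print(hidden.shape, values.shape)
```

How the real `Critic` reduces these per-token values (for example, averaging over response tokens or taking the last token) depends on the base class in `..base` and is not specified here.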