From f5ca0397dd1c0d725c9b8d0c63784c55666245a7 Mon Sep 17 00:00:00 2001
From: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Date: Fri, 3 Mar 2023 15:58:16 +0800
Subject: [PATCH] [chatgpt] fix lora gemini conflict in RM training (#2984)

* fix lora bug

* polish

* fix lora gemini
---
 applications/ChatGPT/chatgpt/nn/reward_model.py     | 4 ++--
 applications/ChatGPT/chatgpt/trainer/rm.py          | 7 +------
 applications/ChatGPT/examples/train_reward_model.py | 2 --
 3 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/applications/ChatGPT/chatgpt/nn/reward_model.py b/applications/ChatGPT/chatgpt/nn/reward_model.py
index 5108f61a6..27cd1ccae 100644
--- a/applications/ChatGPT/chatgpt/nn/reward_model.py
+++ b/applications/ChatGPT/chatgpt/nn/reward_model.py
@@ -24,14 +24,14 @@ class RewardModel(LoRAModule):
                  lora_train_bias: str = 'none') -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
         self.model = model
+        self.convert_to_lora()
+
         if value_head is not None:
             if value_head.out_features != 1:
                 raise ValueError("The value head of reward model's output dim should be 1!")
             self.value_head = value_head
-
         else:
             self.value_head = nn.Linear(model.config.n_embd, 1)
-        self.convert_to_lora()
 
     def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         outputs = self.model(sequences, attention_mask=attention_mask)
diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py
index 3286b8d8d..d44944aee 100644
--- a/applications/ChatGPT/chatgpt/trainer/rm.py
+++ b/applications/ChatGPT/chatgpt/trainer/rm.py
@@ -56,12 +56,7 @@ class RewardModelTrainer(ABC):
                             desc='Train step of epoch %d' % epoch,
                             disable=not is_rank_0())
             # train
-            if use_lora > 0:
-                print("Using Lora")
-                lora.mark_only_lora_as_trainable(self.model.model)
-
-            else:
-                self.model.train()
+            self.model.train()
             for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
                 chosen_ids = chosen_ids.squeeze(1).cuda()
                 c_mask = c_mask.squeeze(1).cuda()
diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py
index c17c6f393..44acba192 100644
--- a/applications/ChatGPT/examples/train_reward_model.py
+++ b/applications/ChatGPT/examples/train_reward_model.py
@@ -66,8 +66,6 @@ def train(args):
     train_dataset = RewardDataset(train_data, tokenizer, max_len)
     eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
 
-    # batch_size here is expected to be C(k,2), k means # response of each prompt
-    # be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1
     trainer = RewardModelTrainer(model=model,
                                  strategy=strategy,
                                  optim=optim,
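
Note (not part of the patch): a minimal, hypothetical usage sketch of the RewardModel constructor after this change, where convert_to_lora() now runs before the default value head is created. The constructor arguments follow the reward_model.py hunk above; the GPT-2 backbone, import path, and rank value are assumptions based on the ChatGPT example scripts.

import torch
from transformers import GPT2Config, GPT2Model

from chatgpt.nn.reward_model import RewardModel

backbone = GPT2Model(GPT2Config())

# With lora_rank > 0, LoRA conversion is applied to the backbone first; the
# default value head (nn.Linear(model.config.n_embd, 1)) is created afterwards
# and left unconverted, which is the ordering this patch enforces and which
# presumably avoids the LoRA/Gemini conflict named in the title.
rm = RewardModel(model=backbone, lora_rank=8, lora_train_bias='none')

input_ids = torch.randint(0, backbone.config.vocab_size, (1, 16))
score = rm(input_ids, attention_mask=torch.ones_like(input_ids))  # reward score(s) for the batch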