Browse Source

[chatgpt]fix train_rm bug with lora (#2741)

pull/2744/head
BlueRum 2 years ago committed by GitHub
parent
commit
648183a960
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      applications/ChatGPT/chatgpt/trainer/rm.py

2
applications/ChatGPT/chatgpt/trainer/rm.py

@ -43,7 +43,7 @@ class RewardModelTrainer(ABC):
# train
if use_lora > 0:
print("Using Lora")
lora.mark_only_lora_as_trainable(self.model)
lora.mark_only_lora_as_trainable(self.model.model)
else:
self.model.train()
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:

Loading…
Cancel
Save