[chatgpt]fix train_rm bug with lora (#2741)

pull/2744/head
BlueRum 2 years ago committed by GitHub
parent b6e3b955c3
commit 648183a960
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -43,7 +43,7 @@ class RewardModelTrainer(ABC):
# train # train
if use_lora > 0: if use_lora > 0:
print("Using Lora") print("Using Lora")
lora.mark_only_lora_as_trainable(self.model) lora.mark_only_lora_as_trainable(self.model.model)
else: else:
self.model.train() self.model.train()
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:

Loading…
Cancel
Save