[chatgpt] fix lora gemini conflict in RM training (#2984)

* fix lora bug

* polish

* fix lora gemini
BlueRum 2023-03-03 15:58:16 +08:00 committed by GitHub
parent 19ad49fb3b
commit f5ca0397dd
3 changed files with 3 additions and 10 deletions


@@ -24,14 +24,14 @@ class RewardModel(LoRAModule):
                 lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()
        if value_head is not None:
            if value_head.out_features != 1:
                raise ValueError("The value head of reward model's output dim should be 1!")
            self.value_head = value_head
        else:
            self.value_head = nn.Linear(model.config.n_embd, 1)
        self.convert_to_lora()

    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
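The position of self.convert_to_lora() relative to the value-head branch decides whether the value head itself gets LoRA-wrapped. A minimal sketch of that interaction, assuming the conversion simply swaps out the nn.Linear submodules that already exist at call time (LoRALinearSketch and convert_linears_to_lora below are illustrative stand-ins, not the actual ColossalAI LoRAModule API):

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    # Wraps a frozen nn.Linear and adds a trainable low-rank A/B update.
    def __init__(self, base: nn.Linear, rank: int) -> None:
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.02)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + x @ self.lora_a.t() @ self.lora_b.t()


def convert_linears_to_lora(module: nn.Module, rank: int) -> None:
    # Recursively replace existing nn.Linear children with LoRA wrappers.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LoRALinearSketch(child, rank))
        else:
            convert_linears_to_lora(child, rank)


backbone = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
convert_linears_to_lora(backbone, rank=4)   # existing linears get wrapped
value_head = nn.Linear(8, 1)                # created afterwards, stays a plain trainable Linear

A layer created after the conversion keeps all of its parameters trainable, while wrapped layers only train their low-rank factors.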


@@ -56,12 +56,7 @@ class RewardModelTrainer(ABC):
                            desc='Train step of epoch %d' % epoch,
                            disable=not is_rank_0())
            # train
            if use_lora > 0:
                print("Using Lora")
                lora.mark_only_lora_as_trainable(self.model.model)
            else:
                self.model.train()
            self.model.train()
            for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
                chosen_ids = chosen_ids.squeeze(1).cuda()
                c_mask = c_mask.squeeze(1).cuda()
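With the use_lora branch gone, the trainer always just switches the model to train mode and leaves parameter trainability to the LoRA module itself; not calling lora.mark_only_lora_as_trainable here avoids flipping requires_grad on parameters that the Gemini strategy already manages, which is presumably the conflict the commit title refers to. For context, the loop that follows scores chosen and rejected responses and optimizes a pairwise ranking objective; a plain-PyTorch sketch of one step (ignoring the strategy object, and assuming a log-sigmoid pairwise loss) could look like this:

import torch.nn.functional as F


def reward_train_step(model, optimizer, chosen_ids, c_mask, reject_ids, r_mask):
    # Score both responses with the reward model; the chosen one should rank higher.
    chosen_reward = model(chosen_ids, attention_mask=c_mask)
    reject_reward = model(reject_ids, attention_mask=r_mask)
    # Bradley-Terry style pairwise loss: -log sigmoid(r_chosen - r_reject).
    loss = -F.logsigmoid(chosen_reward - reject_reward).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In the real trainer the backward and optimizer steps would typically be routed through the strategy object so that Gemini or DDP can handle gradient management.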


@@ -66,8 +66,6 @@ def train(args):
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
    # batch_size here is expected to be C(k,2), k means # response of each prompt
    # be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
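The batch-size comment in this hunk is counting response pairs: with k ranked responses for a prompt there are C(k,2) = k*(k-1)/2 chosen/rejected pairs, and 'Dahoas/rm-static' provides one chosen and one rejected response per prompt, i.e. k = 2 and a single pair. A quick check of the arithmetic:

from math import comb

# Number of chosen/rejected pairs that can be formed from k responses to one prompt.
for k in (2, 3, 4):
    print(f"k={k}: {comb(k, 2)} pairs")   # 1, 3, 6 pairs respectively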