@@ -82,7 +82,7 @@ class ProcessRewardModelTrainer(SLTrainer):
 
             self.wandb_run = wandb.init(project="Coati-PRM", sync_tensorboard=True)
 
-    def _train(self, epoch):
+    def _train(self, epoch: int):
         self.model.train()
         step_bar = tqdm.trange(
             len(self.train_dataloader) // self.accumulation_steps,
@@ -129,6 +129,31 @@ class ProcessRewardModelTrainer(SLTrainer):
                     f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
                 )
 
-    def _eval(epoch: int):
-        # TODO
-        pass
+    def _eval(self, epoch: int):
+        self.model.eval()
+
+        step_bar = tqdm.trange(
+            len(self.eval_dataloader),
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
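+        # Iterate over the eval dataloader and accumulate the per-batch loss.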
+        for batch in self.eval_dataloader:
+            batch = to_device(batch, self.device)
+            logits = self.model(batch["input_ids"])["logits"]
+            loss = self.loss_fn(batch["labels"], logits)
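+            # Average the batch loss across data-parallel ranks before accumulating.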
+            loss_mean = all_reduce_mean(tensor=loss)
+            self.accumulative_meter.add(
+                "loss", loss_mean.to(torch.float16).item(), count_update=batch["input_ids"].size(0)
+            )
+            step_bar.update()
+
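+        # Report the aggregated evaluation loss on the master rank and, if a save
+        # directory is configured, write the result to disk as well.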
+        loss_mean = self.accumulative_meter.get("loss")
+        msg = "Evaluation Result:\n"
+        for tag in ["loss"]:
+            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+        self.coordinator.print_on_master(msg)
+        if self.save_dir is not None:
+            os.makedirs(self.save_dir, exist_ok=True)
+            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                f.write(msg)
+        step_bar.close()