pull/6119/head
Tong Li 1 week ago
parent a8b4afb747
commit 375e356a16

@@ -82,7 +82,7 @@ class ProcessRewardModelTrainer(SLTrainer):
         self.wandb_run = wandb.init(project="Coati-PRM", sync_tensorboard=True)

-    def _train(self, epoch):
+    def _train(self, epoch: int):
         self.model.train()
         step_bar = tqdm.trange(
             len(self.train_dataloader) // self.accumulation_steps,
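
The `// self.accumulation_steps` divisor reflects gradient accumulation: the bar advances once per optimizer step, and each optimizer step consumes `accumulation_steps` micro-batches. A minimal self-contained sketch of that pattern (the toy model, data, and names below are illustrative assumptions, not code from this repo):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [torch.randn(2, 4) for _ in range(8)]
accumulation_steps = 4

for i, x in enumerate(micro_batches):
    # Scale each micro-batch loss so the accumulated grads equal the
    # gradient of the mean loss over the effective (large) batch.
    loss = model(x).pow(2).mean() / accumulation_steps
    loss.backward()  # grads accumulate into param.grad across calls
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # one optimizer step per accumulation window
        optimizer.zero_grad()  # hence len(loader) // accumulation_steps steps total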
@@ -129,6 +129,31 @@ class ProcessRewardModelTrainer(SLTrainer):
                    f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
                )

-    def _eval(epoch: int):
-        # TODO
-        pass
+    def _eval(self, epoch: int):
+        self.model.eval()
+        step_bar = tqdm.trange(
+            len(self.eval_dataloader),
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
+        with torch.no_grad():
+            for batch in self.eval_dataloader:
+                batch = to_device(batch, self.device)
+                logits = self.model(batch["input_ids"])["logits"]
+                loss = self.loss_fn(batch["labels"], logits)
+                loss_mean = all_reduce_mean(tensor=loss)
+                self.accumulative_meter.add(
+                    "loss", loss_mean.to(torch.float16).item(), count_update=batch["input_ids"].size(0)
+                )
+                step_bar.update()
+
+        loss_mean = self.accumulative_meter.get("loss")
+        msg = "Evaluation Result:\n"
+        for tag in ["loss"]:
+            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+        self.coordinator.print_on_master(msg)
+        if self.save_dir is not None:
+            os.makedirs(self.save_dir, exist_ok=True)
+            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                f.write(msg)
+        step_bar.close()
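
For context on the pattern this hunk uses: `all_reduce_mean` averages the per-rank loss across data-parallel workers before logging, so the reported evaluation loss is the global mean rather than rank 0's local value, and the meter then keeps a sample-count-weighted running average across batches. A minimal self-contained sketch of that pattern with plain `torch.distributed` (the meter class here is a hypothetical stand-in for the repo's `accumulative_meter`, not its actual implementation):

import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    # Sum across all data-parallel ranks, then divide by world size,
    # so every rank ends up holding the same global mean.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor

class AccumulativeMeanMeter:
    # Hypothetical stand-in for the meter used above: a per-tag
    # running mean weighted by the number of samples in each batch.
    def __init__(self) -> None:
        self._sums: dict[str, float] = {}
        self._counts: dict[str, int] = {}

    def add(self, tag: str, value: float, count_update: int = 1) -> None:
        self._sums[tag] = self._sums.get(tag, 0.0) + value * count_update
        self._counts[tag] = self._counts.get(tag, 0) + count_update

    def get(self, tag: str) -> float:
        return self._sums.get(tag, 0.0) / max(self._counts.get(tag, 0), 1)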
