diff --git a/applications/ColossalChat/coati/trainer/prm.py b/applications/ColossalChat/coati/trainer/prm.py
index 73fdb6c5e..3c916c308 100644
--- a/applications/ColossalChat/coati/trainer/prm.py
+++ b/applications/ColossalChat/coati/trainer/prm.py
@@ -82,7 +82,7 @@ class ProcessRewardModelTrainer(SLTrainer):
             self.wandb_run = wandb.init(project="Coati-PRM", sync_tensorboard=True)
 
-    def _train(self, epoch):
+    def _train(self, epoch: int):
         self.model.train()
 
         step_bar = tqdm.trange(
             len(self.train_dataloader) // self.accumulation_steps,
@@ -129,6 +129,31 @@ class ProcessRewardModelTrainer(SLTrainer):
                     f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
                 )
 
-    def _eval(epoch: int):
-        # TODO
-        pass
+    def _eval(self, epoch: int):
+        self.model.eval()
+
+        step_bar = tqdm.trange(
+            len(self.eval_dataloader),
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
+        for batch in self.eval_dataloader:
+            batch = to_device(batch, self.device)
+            logits = self.model(batch["input_ids"])["logits"]
+            loss = self.loss_fn(batch["labels"], logits)
+            loss_mean = all_reduce_mean(tensor=loss)
+            self.accumulative_meter.add(
+                "loss", loss_mean.to(torch.float16).item(), count_update=batch["input_ids"].size(0)
+            )
+            step_bar.update()
+
+        loss_mean = self.accumulative_meter.get("loss")
+        msg = "Evaluation Result:\n"
+        for tag in ["loss"]:
+            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+        self.coordinator.print_on_master(msg)
+        if self.save_dir is not None:
+            os.makedirs(self.save_dir, exist_ok=True)
+            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                f.write(msg)
+        step_bar.close()
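
The new `_eval` above runs the forward pass without a `torch.no_grad()` guard, so autograd state is still tracked while evaluating. Below is a minimal, self-contained sketch of the standard evaluation-loop pattern the method follows; it uses plain PyTorch with illustrative names (`evaluate`, `loader`, `loss_fn`), not the coati trainer API, and assumes a simple `(inputs, labels)` dataloader rather than the dict-style batches in the diff.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model: torch.nn.Module, loader: DataLoader, loss_fn) -> float:
    """Sketch of a no-grad eval loop; names are illustrative, not coati's."""
    model.eval()
    total_loss, total_count = 0.0, 0
    with torch.no_grad():  # skip building the autograd graph during evaluation
        for inputs, labels in loader:
            logits = model(inputs)
            loss = loss_fn(logits, labels)
            # weight by batch size, mirroring count_update in the diff's meter
            total_loss += loss.item() * inputs.size(0)
            total_count += inputs.size(0)
    return total_loss / max(total_count, 1)

# Toy usage: a linear classifier on random data
model = torch.nn.Linear(4, 2)
data = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
print(evaluate(model, DataLoader(data, batch_size=4), torch.nn.CrossEntropyLoss()))
```

In a distributed setting the per-batch loss would additionally be averaged across ranks, which is what the `all_reduce_mean` call in the diff does before the value is fed to the accumulative meter.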