From 0b2b454b97d55d1f974c28951fc5465b4ff24a8b Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Tue, 13 Aug 2024 06:48:54 +0000
Subject: [PATCH] fix eval

---
 .../ColossalChat/coati/trainer/sft.py | 82 +++++++++++++------
 1 file changed, 59 insertions(+), 23 deletions(-)

diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index ebdfd5024..6322cb8df 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -182,27 +182,63 @@ class SFTTrainer(SLTrainer):
         self.accumulative_meter.reset()
         self.model.eval()
         with torch.no_grad():
-            step_bar = trange(
-                len(self.eval_dataloader),
-                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-                disable=not is_rank_0(),
-            )
-            for batch in self.eval_dataloader:
-                batch = to_device(batch, torch.cuda.current_device())
-                outputs = self.model(
-                    batch["input_ids"],
-                    attention_mask=batch["attention_mask"],
-                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+            if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+                data_iter = iter(self.eval_dataloader)
+                step_bar = tqdm(
+                    range(len(self.eval_dataloader)),
+                    desc="Step",
+                    disable=not (dist.get_rank() == dist.get_world_size() - 1),
                 )
-                loss_mean = all_reduce_mean(tensor=outputs.loss)
-                self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
-                step_bar.update()
-            loss_mean = self.accumulative_meter.get("loss")
-            msg = "Evaluation Result:\n"
-            for tag in ["loss"]:
-                msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
-            self.coordinator.print_on_master(msg)
-            os.makedirs(self.save_dir, exist_ok=True)
-            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
-                f.write(msg)
-            step_bar.close()
+                for step in step_bar:
+                    outputs = self.booster.execute_pipeline(
+                        data_iter,
+                        self.model,
+                        criterion=lambda outputs, inputs: outputs[0],
+                        optimizer=self.optimizer,
+                        return_loss=True,
+                    )
+                    loss = outputs["loss"]
+                    if dist.get_rank() == dist.get_world_size() - 1:
+                        step_bar.set_postfix({"eval/loss": loss.item()})
+                        self.accumulative_meter.add("loss", loss.item())
+                        step_bar.update()
+
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    loss_mean = self.accumulative_meter.get("loss")
+                    msg = "Evaluation Result:\n"
+                    for tag in ["loss"]:
+                        msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+                    print(msg)
+                    if self.save_dir is not None:
+                        os.makedirs(self.save_dir, exist_ok=True)
+                        with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                            f.write(msg)
+                    step_bar.close()
+
+            else:
+                step_bar = trange(
+                    len(self.eval_dataloader),
+                    desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                    disable=not is_rank_0(),
+                )
+                for batch in self.eval_dataloader:
+                    batch = to_device(batch, torch.cuda.current_device())
+                    outputs = self.model(
+                        batch["input_ids"],
+                        attention_mask=batch["attention_mask"],
+                        labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                    )
+                    loss_mean = all_reduce_mean(tensor=outputs.loss)
+                    self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
+                    step_bar.update()
+
+                loss_mean = self.accumulative_meter.get("loss")
+                msg = "Evaluation Result:\n"
+                for tag in ["loss"]:
+                    msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+                self.coordinator.print_on_master(msg)
+                if self.save_dir is not None:
+                    os.makedirs(self.save_dir, exist_ok=True)
+                    with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                        f.write(msg)
+                step_bar.close()