From 136aa7c5a5132ecd911ba5ca4088df2795f67b3d Mon Sep 17 00:00:00 2001 From: "877825076@qq.com" <877825076@qq.com> Date: Thu, 14 Dec 2023 14:07:40 +0800 Subject: [PATCH] feat(eval): unify evaluation --- internlm/utils/evaluation.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index a94784c..35be7cb 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -65,7 +65,11 @@ def evaluate_on_val_dls( step_count, update_panel: bool = False, streaming: bool = False, + val_steps: int = -1, + num_consumed_tokens: int = 0, + writer_t=None, ): + """Evaluation on different valid datasets.""" with switch_sequence_parallel_mode(): torch.cuda.empty_cache() trainer.eval() @@ -139,6 +143,9 @@ def evaluate_on_val_dls( if verbose: val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item() + if val_idx >= val_steps > 0: + break + assert val_idx != -1 dist.barrier() @@ -148,12 +155,16 @@ def evaluate_on_val_dls( infos = { "step": step_count, f"val/{val_name}_loss": val_loss, + f"val/{val_name}_loss_from_metric": val_res["loss_from_metric"], f"val/{val_name}_acc": val_res["acc"], f"val/{val_name}_plex": val_res["perplexity"], } for key, value in infos.items(): - writer.add_scalar(key=key, value=value, step=step_count) + if writer: + writer.add_scalar(key=key, value=value, step=step_count) + if writer_t: + writer_t.add_scalar(key=key, value=value, step=num_consumed_tokens) if update_panel: logger.info( extra={ "step": step_count, "val_loss": val_loss, + "val_loss_from_metric": val_res["loss_from_metric"], "val_acc": val_res["acc"], "val_perplexity": val_res["perplexity"], },