feat(eval): unify evaluation

pull/541/head
877825076@qq.com 2023-12-14 14:07:40 +08:00
parent bbb5651582
commit 136aa7c5a5
1 changed file with 13 additions and 1 deletion

View File

@ -65,7 +65,11 @@ def evaluate_on_val_dls(
step_count, step_count,
update_panel: bool = False, update_panel: bool = False,
streaming: bool = False, streaming: bool = False,
val_steps: bool = -1,
num_consumed_tokens: int = 0,
writer_t=None,
): ):
"""Evaluation on different valid datasets."""
with switch_sequence_parallel_mode(): with switch_sequence_parallel_mode():
torch.cuda.empty_cache() torch.cuda.empty_cache()
trainer.eval() trainer.eval()
@ -139,6 +143,9 @@ def evaluate_on_val_dls(
if verbose: if verbose:
val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item() val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item()
if val_idx >= val_steps > 0:
break
assert val_idx != -1 assert val_idx != -1
dist.barrier() dist.barrier()
@ -148,12 +155,16 @@ def evaluate_on_val_dls(
infos = { infos = {
"step": step_count, "step": step_count,
f"val/{val_name}_loss": val_loss, f"val/{val_name}_loss": val_loss,
f"val/{val_name}_loss_from_metric": val_res["loss_from_metric"],
f"val/{val_name}_acc": val_res["acc"], f"val/{val_name}_acc": val_res["acc"],
f"val/{val_name}_plex": val_res["perplexity"], f"val/{val_name}_plex": val_res["perplexity"],
} }
for key, value in infos.items(): for key, value in infos.items():
writer.add_scalar(key=key, value=value, step=step_count) if writer:
writer.add_scalar(key=key, value=value, step=step_count)
if writer_t:
writer_t.add_scalar(key=key, value=value, step=num_consumed_tokens)
if update_panel: if update_panel:
logger.info( logger.info(
@ -161,6 +172,7 @@ def evaluate_on_val_dls(
extra={ extra={
"step": step_count, "step": step_count,
"val_loss": val_loss, "val_loss": val_loss,
"val_loss_from_metric": val_res["loss_from_metric"],
"val_acc": val_res["acc"], "val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"], "val_perplexity": val_res["perplexity"],
}, },