feat(eval): unify evaluation

pull/541/head
877825076@qq.com 2023-12-14 14:07:40 +08:00
parent bbb5651582
commit 136aa7c5a5
1 changed files with 13 additions and 1 deletions

View File

@ -65,7 +65,11 @@ def evaluate_on_val_dls(
step_count,
update_panel: bool = False,
streaming: bool = False,
val_steps: bool = -1,
num_consumed_tokens: int = 0,
writer_t=None,
):
"""Evaluation on different valid datasets."""
with switch_sequence_parallel_mode():
torch.cuda.empty_cache()
trainer.eval()
@ -139,6 +143,9 @@ def evaluate_on_val_dls(
if verbose:
val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item()
if val_idx >= val_steps > 0:
break
assert val_idx != -1
dist.barrier()
@ -148,12 +155,16 @@ def evaluate_on_val_dls(
infos = {
"step": step_count,
f"val/{val_name}_loss": val_loss,
f"val/{val_name}_loss_from_metric": val_res["loss_from_metric"],
f"val/{val_name}_acc": val_res["acc"],
f"val/{val_name}_plex": val_res["perplexity"],
}
for key, value in infos.items():
writer.add_scalar(key=key, value=value, step=step_count)
if writer:
writer.add_scalar(key=key, value=value, step=step_count)
if writer_t:
writer_t.add_scalar(key=key, value=value, step=num_consumed_tokens)
if update_panel:
logger.info(
@ -161,6 +172,7 @@ def evaluate_on_val_dls(
extra={
"step": step_count,
"val_loss": val_loss,
"val_loss_from_metric": val_res["loss_from_metric"],
"val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"],
},