mirror of https://github.com/InternLM/InternLM
feat(eval): unify evaluation
parent
bbb5651582
commit
136aa7c5a5
|
@ -65,7 +65,11 @@ def evaluate_on_val_dls(
|
||||||
step_count,
|
step_count,
|
||||||
update_panel: bool = False,
|
update_panel: bool = False,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
|
val_steps: bool = -1,
|
||||||
|
num_consumed_tokens: int = 0,
|
||||||
|
writer_t=None,
|
||||||
):
|
):
|
||||||
|
"""Evaluation on different valid datasets."""
|
||||||
with switch_sequence_parallel_mode():
|
with switch_sequence_parallel_mode():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
trainer.eval()
|
trainer.eval()
|
||||||
|
@ -139,6 +143,9 @@ def evaluate_on_val_dls(
|
||||||
if verbose:
|
if verbose:
|
||||||
val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item()
|
val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item()
|
||||||
|
|
||||||
|
if val_idx >= val_steps > 0:
|
||||||
|
break
|
||||||
|
|
||||||
assert val_idx != -1
|
assert val_idx != -1
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
@ -148,12 +155,16 @@ def evaluate_on_val_dls(
|
||||||
infos = {
|
infos = {
|
||||||
"step": step_count,
|
"step": step_count,
|
||||||
f"val/{val_name}_loss": val_loss,
|
f"val/{val_name}_loss": val_loss,
|
||||||
|
f"val/{val_name}_loss_from_metric": val_res["loss_from_metric"],
|
||||||
f"val/{val_name}_acc": val_res["acc"],
|
f"val/{val_name}_acc": val_res["acc"],
|
||||||
f"val/{val_name}_plex": val_res["perplexity"],
|
f"val/{val_name}_plex": val_res["perplexity"],
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value in infos.items():
|
for key, value in infos.items():
|
||||||
|
if writer:
|
||||||
writer.add_scalar(key=key, value=value, step=step_count)
|
writer.add_scalar(key=key, value=value, step=step_count)
|
||||||
|
if writer_t:
|
||||||
|
writer_t.add_scalar(key=key, value=value, step=num_consumed_tokens)
|
||||||
|
|
||||||
if update_panel:
|
if update_panel:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -161,6 +172,7 @@ def evaluate_on_val_dls(
|
||||||
extra={
|
extra={
|
||||||
"step": step_count,
|
"step": step_count,
|
||||||
"val_loss": val_loss,
|
"val_loss": val_loss,
|
||||||
|
"val_loss_from_metric": val_res["loss_from_metric"],
|
||||||
"val_acc": val_res["acc"],
|
"val_acc": val_res["acc"],
|
||||||
"val_perplexity": val_res["perplexity"],
|
"val_perplexity": val_res["perplexity"],
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in New Issue