diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 35be7cb..68a0950 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -103,10 +103,16 @@ def evaluate_on_val_dls( if gpc.is_using_pp(): total_val_bsz = len(batch[1]) assert total_val_bsz % data_cfg.micro_bsz == 0 - num_microbatches = total_val_bsz // data_cfg.micro_bsz - tensor_shape = torch.Size( - [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]] - ) + if data_cfg.get("valid_pack_mode", None) is None or data_cfg.valid_pack_mode is False: + num_microbatches = total_val_bsz // data_cfg.micro_bsz + tensor_shape = torch.Size( + [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]] + ) + else: + num_microbatches = total_val_bsz + tensor_shape = torch.Size( + [1, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]] + ) with switch_evaluation_pipeline_scheduler( trainer=trainer,