From 144731c35c47171ab675e5fc9557468450a5a666 Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Mon, 9 Oct 2023 20:04:27 +0800 Subject: [PATCH] fix evaluation bug in pp --- internlm/initialize/launch.py | 2 +- internlm/utils/evaluation.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 3651a4c..5bd2b73 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -283,7 +283,7 @@ def args_sanity_check(): if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = "origin_tp" - if gpc.config.parallel["tensor"].get("mode", None) is 'fstp': + if gpc.config.parallel["tensor"].get("mode", None) == 'fstp': assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True." # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 2a11a47..148d19d 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -106,9 +106,15 @@ def evaluate_on_val_dls( total_val_bsz = len(batch[1]) assert total_val_bsz % data_cfg.micro_bsz == 0 num_microbatches = total_val_bsz // data_cfg.micro_bsz - tensor_shape = torch.Size( - [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] - ) + if gpc.config.parallel['tensor']['mode'] == 'fstp': + sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR) + tensor_shape = torch.Size( + [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE] + ) + else: + tensor_shape = torch.Size( + [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] + ) with switch_evaluation_pipeline_scheduler( trainer=trainer,