From ffd9a3cbc95bb9996cf56b9a05de342cfcc17e05 Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Wed, 11 Oct 2023 19:30:41 +0800 Subject: [PATCH] [hotfix] fix bug in sequence parallel test (#4887) --- tests/test_shardformer/test_model/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 0a2b151d4..66d77b48a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -160,7 +160,7 @@ def run_forward_backward_with_hybrid_plugin( input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: - data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) + data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,)) sharded_model.train() if booster.plugin.stage_manager is not None: