From 5aee4261a60586b7cf5eda3992f247ff5569aedc Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 28 Oct 2024 06:06:07 +0000 Subject: [PATCH] [fix] fix test zerobubble --- colossalai/shardformer/modeling/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 47c17e749..7a04c5451 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -82,7 +82,7 @@ class LlamaPipelineForwards: elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape[:2] + batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: