From c7d6975d2984825df40bb86ac39fc1c3d137fe96 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 13 Sep 2023 15:57:16 +0800 Subject: [PATCH] [shardformer] fix GPT2DoubleHeadsModel (#4703) --- colossalai/shardformer/modeling/gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index bc99be4cc..84deafefe 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -94,9 +94,9 @@ class GPT2PipelineForwards: if hidden_states is None: raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] - batch_size = input_shape[0] device = hidden_states.device hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) + batch_size = hidden_states.shape[0] # GPT2Attention mask. if attention_mask is not None: