[shardformer] fix GPT2DoubleHeadsModel (#4703)

pull/4709/head
flybird11111 2023-09-13 15:57:16 +08:00 committed by GitHub
parent 068372a738
commit c7d6975d29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 1 deletion

View File

@@ -94,9 +94,9 @@ class GPT2PipelineForwards:
if hidden_states is None: if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1] input_shape = hidden_states.size()[:-1]
batch_size = input_shape[0]
device = hidden_states.device device = hidden_states.device
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
batch_size = hidden_states.shape[0]
# GPT2Attention mask. # GPT2Attention mask.
if attention_mask is not None: if attention_mask is not None: