fix attn mask shape of gpt (#472)

pull/469/head
ver217 3 years ago committed by GitHub
parent 3cb3fc275e
commit 1559c0df41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -292,7 +292,7 @@ class GPT(nn.Module):
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # Adapted from huggingface
            if attention_mask is not None:
-               batch_size = x.shape[0]
+               batch_size = input_ids.shape[0]
                attention_mask = attention_mask.view(batch_size, -1)
                attention_mask = col_nn.partition_batch(attention_mask)
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

Loading…
Cancel
Save