|
|
@ -292,7 +292,7 @@ class GPT(nn.Module):
|
|
|
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
|
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
|
|
# Adapted from huggingface
|
|
|
|
# Adapted from huggingface
|
|
|
|
if attention_mask is not None:
|
|
|
|
if attention_mask is not None:
|
|
|
|
batch_size = x.shape[0]
|
|
|
|
batch_size = input_ids.shape[0]
|
|
|
|
attention_mask = attention_mask.view(batch_size, -1)
|
|
|
|
attention_mask = attention_mask.view(batch_size, -1)
|
|
|
|
attention_mask = col_nn.partition_batch(attention_mask)
|
|
|
|
attention_mask = col_nn.partition_batch(attention_mask)
|
|
|
|
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
|
|
|
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
|
|
|