|
|
@@ -292,7 +292,7 @@ class GPT(nn.Module):
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # Adapted from huggingface
         if attention_mask is not None:
-            batch_size = x.shape[0]
+            batch_size = input_ids.shape[0]
             attention_mask = attention_mask.view(batch_size, -1)
             attention_mask = col_nn.partition_batch(attention_mask)
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
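For context (not part of the diff): a minimal sketch of the mask handling this hunk touches, in plain single-device PyTorch. The `col_nn.partition_batch` call (ColossalAI's tensor-parallel batch split, a no-op on one device) is omitted, and the tensor sizes and the final `-10000.0` additive-mask step are illustrative assumptions following the HuggingFace pattern the code comment cites.

```python
import torch

# Illustrative shapes (assumed, not from the diff).
batch_size, seq_len = 2, 5
input_ids = torch.randint(0, 100, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[1, 3:] = 0  # pretend the second sequence is padded

# The fix: read the batch size from input_ids rather than x,
# since input_ids always carries the original batch dimension.
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Shape is now [batch_size, 1, 1, to_seq_length], which broadcasts against
# attention scores of shape [batch_size, num_heads, from_seq_length, to_seq_length].
assert attention_mask.shape == (2, 1, 1, 5)

# Assumed follow-up step in the HuggingFace style: 0 where attended,
# a large negative value where masked, added to the attention scores.
attention_mask = (1.0 - attention_mask) * -10000.0
```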
|
|
|