mirror of https://github.com/hpcaitech/ColossalAI
fixed gpt attention mask in pipeline (#430)
parent
f9c762df85
commit
0f5f5dd556
|
@ -51,18 +51,6 @@ class GPTEmbedding(nn.Module):
|
|||
x = x + self.tokentype_embeddings(tokentype_ids)
|
||||
x = self.dropout(x)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# Adapted from huggingface
|
||||
if attention_mask is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = col_nn.partition_batch(attention_mask)
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
return x, attention_mask
|
||||
|
||||
|
||||
|
@ -355,6 +343,21 @@ class PipelineGPT(nn.Module):
|
|||
if self.first:
|
||||
x, attention_mask = self.embed(input_ids, attention_mask)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# Adapted from huggingface
|
||||
if attention_mask is not None:
|
||||
if self.first:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = x.shape[0]
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = col_nn.partition_batch(attention_mask)
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
for block in self.blocks:
|
||||
x, attention_mask = block(x, attention_mask)
|
||||
|
||||
|
|
Loading…
Reference in New Issue