|
|
@ -51,18 +51,6 @@ class GPTEmbedding(nn.Module): |
|
|
|
x = x + self.tokentype_embeddings(tokentype_ids) |
|
|
|
x = x + self.tokentype_embeddings(tokentype_ids) |
|
|
|
x = self.dropout(x) |
|
|
|
x = self.dropout(x) |
|
|
|
|
|
|
|
|
|
|
|
# We create a 3D attention mask from a 2D tensor mask. |
|
|
|
|
|
|
|
# Sizes are [batch_size, 1, 1, to_seq_length] |
|
|
|
|
|
|
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] |
|
|
|
|
|
|
|
# Adapted from huggingface |
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
|
|
batch_size = input_ids.shape[0] |
|
|
|
|
|
|
|
attention_mask = attention_mask.view(batch_size, -1) |
|
|
|
|
|
|
|
attention_mask = col_nn.partition_batch(attention_mask) |
|
|
|
|
|
|
|
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) |
|
|
|
|
|
|
|
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility |
|
|
|
|
|
|
|
attention_mask = (1.0 - attention_mask) * -10000.0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return x, attention_mask |
|
|
|
return x, attention_mask |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -355,6 +343,21 @@ class PipelineGPT(nn.Module): |
|
|
|
if self.first: |
|
|
|
if self.first: |
|
|
|
x, attention_mask = self.embed(input_ids, attention_mask) |
|
|
|
x, attention_mask = self.embed(input_ids, attention_mask) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# We create a 3D attention mask from a 2D tensor mask. |
|
|
|
|
|
|
|
# Sizes are [batch_size, 1, 1, to_seq_length] |
|
|
|
|
|
|
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] |
|
|
|
|
|
|
|
# Adapted from huggingface |
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
|
|
if self.first: |
|
|
|
|
|
|
|
batch_size = input_ids.shape[0] |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
batch_size = x.shape[0] |
|
|
|
|
|
|
|
attention_mask = attention_mask.view(batch_size, -1) |
|
|
|
|
|
|
|
attention_mask = col_nn.partition_batch(attention_mask) |
|
|
|
|
|
|
|
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) |
|
|
|
|
|
|
|
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility |
|
|
|
|
|
|
|
attention_mask = (1.0 - attention_mask) * -10000.0 |
|
|
|
|
|
|
|
|
|
|
|
for block in self.blocks: |
|
|
|
for block in self.blocks: |
|
|
|
x, attention_mask = block(x, attention_mask) |
|
|
|
x, attention_mask = block(x, attention_mask) |
|
|
|
|
|
|
|
|
|
|
|