embedding remove attn mask (#474)

pull/475/head
ver217 2022-03-21 14:53:23 +08:00 committed by GitHub
parent 7544347145
commit d70f43dd7a
1 changed file with 4 additions and 4 deletions


@@ -43,7 +43,7 @@ class GPTEmbedding(nn.Module):
     def word_embedding_weight(self):
         return self.word_embeddings.weight
 
-    def forward(self, input_ids, attention_mask=None, position_ids=None, tokentype_ids=None):
+    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
         seq_length = input_ids.size(1)
         if position_ids is None:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
@@ -52,7 +52,7 @@ class GPTEmbedding(nn.Module):
             x = x + self.tokentype_embeddings(tokentype_ids)
         x = self.dropout(x)
 
-        return x, attention_mask
+        return x
 
 
 @LAYERS.register_module
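
After this change, GPTEmbedding.forward takes only input_ids (with optional position_ids and tokentype_ids) and returns the hidden states alone; the attention mask is no longer passed in or echoed back. A minimal self-contained sketch of an embedding module with the new signature (plain nn.Embedding layers and the sizes below are illustrative stand-ins, not the ColossalAI parallel layers; tokentype_ids is accepted but ignored here for brevity):

import torch
import torch.nn as nn

class SimpleGPTEmbedding(nn.Module):
    # Illustrative stand-in for the post-change signature: no attention_mask in, none out.
    def __init__(self, vocab_size=50304, hidden_size=768, max_seq_len=1024, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        x = self.dropout(x)
        return x  # hidden states only; the caller keeps the attention mask
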
@@ -285,7 +285,7 @@ class GPT(nn.Module):
                               dtype=dtype)
 
     def forward(self, input_ids, attention_mask=None):
-        x, attention_mask = self.embed(input_ids, attention_mask)
+        x = self.embed(input_ids)
 
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
@@ -362,7 +362,7 @@ class PipelineGPT(nn.Module):
 
     def forward(self, x=None, input_ids=None, attention_mask=None):
         if self.first:
-            x, attention_mask = self.embed(input_ids, attention_mask)
+            x = self.embed(input_ids)
 
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
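
The comment at both call sites refers to the usual GPT-style mask expansion that still happens in GPT.forward and PipelineGPT.forward, now applied to the mask the caller kept rather than one returned by the embedding: a 2D padding mask of shape [batch_size, to_seq_length] is reshaped to [batch_size, 1, 1, to_seq_length] so it broadcasts over heads and query positions, then converted to an additive bias. A hedged sketch of that expansion (the helper name extend_attention_mask is illustrative; the in-tree code may differ in details such as dtype handling):

import torch

def extend_attention_mask(attention_mask, dtype=torch.float32):
    # Input: [batch_size, to_seq_length] with 1 = keep, 0 = pad.
    # Output: additive bias [batch_size, 1, 1, to_seq_length] that broadcasts over
    # [batch_size, num_heads, from_seq_length, to_seq_length] attention scores.
    mask = attention_mask[:, None, None, :].to(dtype)
    return (1.0 - mask) * -10000.0

# Example: the padded position gets a large negative bias before softmax.
mask = torch.tensor([[1, 1, 1, 0]])
bias = extend_attention_mask(mask)
print(bias.shape)  # torch.Size([1, 1, 1, 4])
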