From 304263c2cee51839898ac4bf8fedea1121dd5588 Mon Sep 17 00:00:00 2001
From: ver217
Date: Fri, 18 Mar 2022 17:24:19 +0800
Subject: [PATCH] fix gpt attention mask (#461)

---
 model_zoo/gpt/gpt.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/model_zoo/gpt/gpt.py b/model_zoo/gpt/gpt.py
index d544f9d71..dadbc152b 100644
--- a/model_zoo/gpt/gpt.py
+++ b/model_zoo/gpt/gpt.py
@@ -21,6 +21,7 @@ __all__ = [

 @LAYERS.register_module
 class GPTEmbedding(nn.Module):
+
     def __init__(self,
                  embedding_dim: int,
                  vocab_size: int,
@@ -56,6 +57,7 @@ class GPTEmbedding(nn.Module):

 @LAYERS.register_module
 class GPTSelfAttention(nn.Module):
+
     def __init__(self,
                  dim: int,
                  num_heads: int,
@@ -70,7 +72,8 @@ class GPTSelfAttention(nn.Module):
         self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
         if fuse_scale_mask_softmax:
             from colossalai.kernel import FusedScaleMaskSoftmax
-            from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+            from colossalai.kernel.cuda_native.scaled_softmax import \
+                AttnMaskType
             self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True,
                                                  input_in_bf16=False,
                                                  attn_mask_type=AttnMaskType.causal,
@@ -113,7 +116,7 @@ class GPTSelfAttention(nn.Module):

         x = torch.matmul(x, v)
         x = x.transpose(1, 2)
-        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
+        new_context_layer_shape = x.size()[:-2] + (all_head_size,)
         x = x.reshape(new_context_layer_shape)

         x = self.dense(x)
@@ -124,6 +127,7 @@ class GPTSelfAttention(nn.Module):

 @LAYERS.register_module
 class GPTMLP(nn.Module):
+
     def __init__(self,
                  dim: int,
                  mlp_ratio: float,
@@ -148,6 +152,7 @@ class GPTMLP(nn.Module):

 @LAYERS.register_module
 class GPTBlock(CheckpointModule):
+
     def __init__(self,
                  dim: int,
                  num_heads: int,
@@ -194,6 +199,7 @@ class GPTBlock(CheckpointModule):

 @LAYERS.register_module
 class GPTLMHead(nn.Module):
+
     def __init__(self,
                  dim: int,
                  vocab_size: int,
@@ -214,6 +220,7 @@ class GPTLMHead(nn.Module):

 @LOSSES.register_module
 class GPTLMLoss(nn.Module):
+
     def __init__(self):
         super().__init__()
         self.loss = col_nn.CrossEntropyLoss()
@@ -227,6 +234,7 @@ class GPTLMLoss(nn.Module):

 @MODELS.register_module
 class GPT(nn.Module):
+
     def __init__(self,
                  vocab_size: int = 50304,
                  max_position_embeddings: int = 1024,
@@ -279,6 +287,18 @@ class GPT(nn.Module):

     def forward(self, input_ids, attention_mask=None):
         x, attention_mask = self.embed(input_ids, attention_mask)
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # Adapted from huggingface
+        if attention_mask is not None:
+            batch_size = x.shape[0]
+            attention_mask = attention_mask.view(batch_size, -1)
+            attention_mask = col_nn.partition_batch(attention_mask)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.to(dtype=x.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
         for block in self.blocks:
             x, attention_mask = block(x, attention_mask)

@@ -288,6 +308,7 @@ class GPT(nn.Module):


 class PipelineGPT(nn.Module):
+
     def __init__(self,
                  vocab_size: int = 50304,
                  max_position_embeddings: int = 1024,
@@ -355,7 +376,7 @@ class PipelineGPT(nn.Module):
             attention_mask = attention_mask.view(batch_size, -1)
             attention_mask = col_nn.partition_batch(attention_mask)
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility
+            attention_mask = attention_mask.to(dtype=x.dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0

         for block in self.blocks:
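
For reference, below is a minimal, self-contained sketch of the mask handling this patch adds to GPT.forward, using plain PyTorch only. It assumes a [batch_size, seq_len] padding mask of 1s (keep) and 0s (pad); the tensor-parallel call col_nn.partition_batch from the patch is omitted, and the helper name build_additive_attention_mask is illustrative, not part of the Colossal-AI API.

import torch

def build_additive_attention_mask(attention_mask: torch.Tensor,
                                  dtype: torch.dtype) -> torch.Tensor:
    """Turn a [batch_size, seq_len] padding mask of 1s (keep) and 0s (pad)
    into an additive mask of shape [batch_size, 1, 1, seq_len] that broadcasts
    over attention scores of shape [batch_size, num_heads, from_seq, to_seq].

    Illustrative sketch of the patch's logic; col_nn.partition_batch (the
    tensor-parallel split in the real code) is intentionally left out."""
    batch_size = attention_mask.shape[0]
    mask = attention_mask.view(batch_size, -1)    # [B, S]
    mask = mask.unsqueeze(1).unsqueeze(2)         # [B, 1, 1, S]
    mask = mask.to(dtype=dtype)                   # match activation dtype (fp16-safe)
    return (1.0 - mask) * -10000.0                # 0 where kept, -10000 where padded

# Usage: add the result to the raw attention scores before softmax, so padded
# positions get a large negative bias and receive ~0 attention weight.
scores = torch.randn(2, 12, 8, 8)                 # [B, heads, from_seq, to_seq]
pad_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0],
                         [1, 1, 1, 0, 0, 0, 0, 0]])
probs = torch.softmax(scores + build_additive_attention_mask(pad_mask, scores.dtype), dim=-1)

The additive form is what makes the fp16 cast matter: -10000.0 is still representable in half precision, and adding it before softmax drives the masked logits to effectively zero probability without any boolean indexing inside the attention kernel.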