@@ -21,6 +21,7 @@ __all__ = [
 
 @LAYERS.register_module
 class GPTEmbedding(nn.Module):
+
     def __init__(self,
                  embedding_dim: int,
                  vocab_size: int,
@@ -56,6 +57,7 @@ class GPTEmbedding(nn.Module):
 
 @LAYERS.register_module
 class GPTSelfAttention(nn.Module):
+
     def __init__(self,
                  dim: int,
                  num_heads: int,
@@ -70,7 +72,8 @@ class GPTSelfAttention(nn.Module):
         self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
         if fuse_scale_mask_softmax:
             from colossalai.kernel import FusedScaleMaskSoftmax
-            from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+            from colossalai.kernel.cuda_native.scaled_softmax import \
+                AttnMaskType
             self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True,
                                                  input_in_bf16=False,
                                                  attn_mask_type=AttnMaskType.causal,
@@ -113,7 +116,7 @@ class GPTSelfAttention(nn.Module):
 
         x = torch.matmul(x, v)
         x = x.transpose(1, 2)
-        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
+        new_context_layer_shape = x.size()[:-2] + (all_head_size,)
         x = x.reshape(new_context_layer_shape)
 
         x = self.dense(x)
@@ -124,6 +127,7 @@ class GPTSelfAttention(nn.Module):
 
 @LAYERS.register_module
 class GPTMLP(nn.Module):
+
     def __init__(self,
                  dim: int,
                  mlp_ratio: float,
@@ -148,6 +152,7 @@ class GPTMLP(nn.Module):
 
 @LAYERS.register_module
 class GPTBlock(CheckpointModule):
+
     def __init__(self,
                  dim: int,
                  num_heads: int,
@@ -194,6 +199,7 @@ class GPTBlock(CheckpointModule):
 
 @LAYERS.register_module
 class GPTLMHead(nn.Module):
+
     def __init__(self,
                  dim: int,
                  vocab_size: int,
@@ -214,6 +220,7 @@ class GPTLMHead(nn.Module):
 
 @LOSSES.register_module
 class GPTLMLoss(nn.Module):
+
     def __init__(self):
         super().__init__()
         self.loss = col_nn.CrossEntropyLoss()
@@ -227,6 +234,7 @@ class GPTLMLoss(nn.Module):
 
 @MODELS.register_module
 class GPT(nn.Module):
+
     def __init__(self,
                  vocab_size: int = 50304,
                  max_position_embeddings: int = 1024,
@@ -279,6 +287,18 @@ class GPT(nn.Module):
     def forward(self, input_ids, attention_mask=None):
         x, attention_mask = self.embed(input_ids, attention_mask)
 
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # Adapted from huggingface
+        if attention_mask is not None:
+            batch_size = x.shape[0]
+            attention_mask = attention_mask.view(batch_size, -1)
+            attention_mask = col_nn.partition_batch(attention_mask)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.to(dtype=x.dtype)    # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
         for block in self.blocks:
             x, attention_mask = block(x, attention_mask)
 
@@ -288,6 +308,7 @@ class GPT(nn.Module):
 
 
 class PipelineGPT(nn.Module):
+
     def __init__(self,
                  vocab_size: int = 50304,
                  max_position_embeddings: int = 1024,
@@ -355,7 +376,7 @@ class PipelineGPT(nn.Module):
             attention_mask = attention_mask.view(batch_size, -1)
             attention_mask = col_nn.partition_batch(attention_mask)
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = attention_mask.to(dtype=x.dtype)    # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0
 
         for block in self.blocks: