@@ -165,8 +165,9 @@ class GPTBlock(CheckpointModule):
                  bias: bool = True,
                  apply_post_layernorm: bool = False,
                  fuse_scale_mask_softmax: bool = False,
-                 checkpoint: bool = False):
-        super().__init__(checkpoint)
+                 checkpoint: bool = False,
+                 activation_offload: bool = False):
+        super().__init__(checkpoint, activation_offload)
         self.apply_post_layernorm = apply_post_layernorm
         self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
         self.attn = GPTSelfAttention(dim=dim,
@@ -252,7 +253,8 @@ class GPT(nn.Module):
                  bias: bool = True,
                  apply_post_layernorm: bool = False,
                  fuse_scale_mask_softmax: bool = False,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 activation_offload: bool = False) -> None:
         super().__init__()
         self.embed = GPTEmbedding(embedding_dim=dim,
                                   vocab_size=vocab_size,
@@ -274,6 +276,7 @@ class GPT(nn.Module):
                 apply_post_layernorm=apply_post_layernorm,
                 fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                 checkpoint=checkpoint,
+                activation_offload=activation_offload
             ) for _ in range(depth)
         ])
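For reference, a minimal usage sketch of the new flag. The only facts taken from the diff are that GPT now accepts activation_offload and forwards it through each GPTBlock to CheckpointModule.__init__ alongside checkpoint; the import path and the choice of argument values below are assumptions for illustration, not part of this patch.

# Illustrative sketch only (not part of the patch): assumes the GPT class is
# importable from the Titans model zoo with the signature shown above.
from titans.model.gpt import GPT  # hypothetical import path

# checkpoint=True enables per-block activation checkpointing;
# activation_offload=True is the flag added by this patch and is passed
# through each GPTBlock to CheckpointModule.__init__(checkpoint, activation_offload).
model = GPT(checkpoint=True, activation_offload=True)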