ColossalAI/examples/language/gpt/titans/model/pipeline_gpt1d.py

350 lines
13 KiB
Python

import inspect
# import model_zoo.gpt.gpt as col_gpt
import titans.model.gpt.gpt as col_gpt
import torch
import torch.nn as nn
from colossalai import kernel
from colossalai import nn as col_nn
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.legacy.pipeline.utils import partition_uniform
from colossalai.logging import get_dist_logger
from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
__all__ = [
"GPT2_small_pipeline_1D",
"GPT2_exlarge_pipeline_1D",
"GPT3_pipeline_1D",
"GPT2_exlarge_pipeline_hybrid",
"GPT2_small_pipeline_hybrid",
"GPT3_pipeline_hybrid",
]
class GenericPipelineGPT(nn.Module):
def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
super().__init__()
self.embedding = embedding
self.blocks = blocks
self.norm = norm
self.head = head
assert blocks is not None
if norm is not None or head is not None:
assert norm is not None and head is not None
def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
if self.embedding is not None:
hidden_states = self.embedding(input_ids=input_ids)
batch_size = hidden_states.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
for block in self.blocks:
hidden_states, attention_mask = block(hidden_states, attention_mask)
if self.norm is not None:
hidden_states = self.head(self.norm(hidden_states))
return hidden_states
class PipelineGPT1D(GenericPipelineGPT):
def __init__(
self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.0,
act_func: str = "gelu",
mlp_ratio: int = 4.0,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False,
):
embedding = None
norm = None
head = None
embed_cls = VocabParallelEmbedding
head_cls = VocabParallelGPTLMHead1D
if embed_split_hidden:
embed_cls = HiddenParallelEmbedding
head_cls = HiddenParallelGPTLMHead1D
if first:
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
blocks = nn.ModuleList(
[
GPTTransformerLayer1D(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
)
for _ in range(num_layers)
]
)
if last:
norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
class FusedPipelineGPT1D(GenericPipelineGPT):
def __init__(
self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.0,
act_func: str = "gelu",
mlp_ratio: int = 4.0,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False,
):
embedding = None
norm = None
head = None
embed_cls = VocabParallelEmbedding
head_cls = VocabParallelGPTLMHead1D
if embed_split_hidden:
embed_cls = HiddenParallelEmbedding
head_cls = HiddenParallelGPTLMHead1D
if first:
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
blocks = nn.ModuleList(
[
FusedGPTTransformerLayer1D(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
)
for _ in range(num_layers)
]
)
if last:
norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
if self.embedding is not None:
hidden_states = self.embedding(input_ids=input_ids)
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
for block in self.blocks:
hidden_states, attention_mask = block(hidden_states, attention_mask)
if self.norm is not None:
hidden_states = self.head(self.norm(hidden_states))
return hidden_states
class PipelineGPTHybrid(GenericPipelineGPT):
def __init__(
self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.0,
act_func: str = "gelu",
mlp_ratio: int = 4,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False,
):
embedding = None
norm = None
head = None
if first:
embedding = col_gpt.GPTEmbedding(
hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype
)
blocks = nn.ModuleList(
[
col_gpt.GPTBlock(
hidden_size,
num_attention_heads,
mlp_ratio=mlp_ratio,
attention_dropout=attn_drop_rate,
dropout=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
activation=nn.functional.gelu,
)
for _ in range(num_layers)
]
)
if last:
norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
# head = col_gpt.GPTLMHead(vocab_size=vocab_size,
# hidden_size=hidden_size,
# dtype=dtype,
# bias=False)
head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)
super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
def _filter_kwargs(func, kwargs):
sig = inspect.signature(func)
return {k: v for k, v in kwargs.items() if k in sig.parameters}
def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
logger = get_dist_logger()
if gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
else:
pipeline_size = 1
pipeline_rank = 0
rank = gpc.get_global_rank()
if pipeline_size > 1:
wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
else:
wrapper = None
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
kwargs["num_layers"] = end - start
kwargs["first"] = start == 0
kwargs["last"] = end == num_layers
logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers")
chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
if wrapper is not None:
if start == 0:
wrapper.register_module(chunk.embedding.word_embeddings)
elif end == num_layers:
wrapper.register_module(chunk.head)
models.append(chunk)
if len(models) == 1:
model = models[0]
else:
model = nn.ModuleList(models)
numel = 0
for _, param in model.named_parameters(recurse=True):
numel += param.numel()
logger.info(f"Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB")
return model
def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device("cuda"), fused=False, **kwargs):
model = FusedPipelineGPT1D if fused else PipelineGPT1D
return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)
def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)
def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(
hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)
def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(
hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)
def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(
hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)
def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(
hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)
def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(
hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)
def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(
hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)