mirror of https://github.com/hpcaitech/ColossalAI
323 lines
13 KiB
Python
323 lines
13 KiB
Python
|
import inspect
|
||
|
|
||
|
# import model_zoo.gpt.gpt as col_gpt
|
||
|
import titans.model.gpt.gpt as col_gpt
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from colossalai import kernel
|
||
|
from colossalai import nn as col_nn
|
||
|
from colossalai.context.parallel_mode import ParallelMode
|
||
|
from colossalai.core import global_context as gpc
|
||
|
from colossalai.logging import get_dist_logger
|
||
|
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
|
||
|
from colossalai.pipeline.utils import partition_uniform
|
||
|
|
||
|
from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
|
||
|
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
|
||
|
|
||
|
# Public factory functions exported by this module.
__all__ = [
    # 1D tensor-parallel pipeline builders
    'GPT2_small_pipeline_1D',
    'GPT2_exlarge_pipeline_1D',
    'GPT3_pipeline_1D',
    # hybrid pipeline builders
    'GPT2_exlarge_pipeline_hybrid',
    'GPT2_small_pipeline_hybrid',
    'GPT3_pipeline_hybrid',
]
|
||
|
|
||
|
|
||
|
class GenericPipelineGPT(nn.Module):
    """A single pipeline stage of a GPT model.

    Every stage owns a sequence of transformer blocks. A stage may additionally
    own the input embedding and/or the final ``norm`` + ``head`` pair.

    Args:
        embedding: optional module, called as ``embedding(input_ids=...)`` to
            produce the hidden states fed to the blocks.
        blocks: iterable of modules, each called as
            ``block(hidden_states, attention_mask)`` and returning the updated
            ``(hidden_states, attention_mask)`` pair. Must not be ``None``.
        norm: optional module applied to the block output before ``head``.
        head: optional module applied after ``norm``. ``norm`` and ``head``
            must be given together or not at all.
    """

    def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
        super().__init__()
        # Attributes are assigned in the order embedding, blocks, norm, head.
        self.embedding, self.blocks = embedding, blocks
        self.norm, self.head = norm, head
        assert blocks is not None
        # norm and head only make sense as a pair.
        assert (norm is None) == (head is None)

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        """Run this stage and return the resulting hidden states.

        ``hidden_states`` is replaced by the embedding output when this stage
        has an embedding. ``attention_mask`` is required.
        """
        if self.embedding is not None:
            hidden_states = self.embedding(input_ids=input_ids)

        # Flatten the mask per sample, then insert two singleton dims:
        # (batch, n) -> (batch, 1, 1, n).
        mask = attention_mask.view(hidden_states.shape[0], -1)[:, None, None, :]
        mask = mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
        # Turn the mask into an additive bias: 1 -> 0, 0 -> -10000.
        mask = (1.0 - mask) * -10000.0

        for layer in self.blocks:
            hidden_states, mask = layer(hidden_states, mask)

        if self.norm is None:
            return hidden_states
        return self.head(self.norm(hidden_states))
|
||
|
|
||
|
|
||
|
class PipelineGPT1D(GenericPipelineGPT):
    """Pipeline stage built from ``GPTTransformerLayer1D`` blocks.

    Args:
        num_layers: number of transformer blocks created for this stage.
        hidden_size: hidden dimension, passed to the embedding, every block,
            the final LayerNorm and the head (as ``embed_dim``).
        num_attention_heads: forwarded to every block.
        vocab_size: forwarded to the embedding and the head.
        embed_drop_rate: forwarded to the embedding.
        act_func: forwarded to every block.
        mlp_ratio: forwarded to every block.
        attn_drop_rate: forwarded to every block as ``attention_dropout_prob``.
        drop_rate: forwarded to every block as ``hidden_dropout_prob``.
        dtype: forwarded to the embedding, every block and the head.
        checkpoint: forwarded to every block.
        max_position_embeddings: forwarded to the embedding and every block.
        layer_norm_epsilon: forwarded to every block and used as ``eps`` of
            the final LayerNorm.
        apply_post_layer_norm: forwarded to every block.
        first: when True, this stage also creates the input embedding.
        last: when True, this stage also creates the final LayerNorm and head.
        embed_split_hidden: when True, use ``HiddenParallelEmbedding`` /
            ``HiddenParallelGPTLMHead1D`` instead of the ``VocabParallel*``
            classes.
    """

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden: bool = False):
        embedding = None
        norm = None
        head = None
        # Pick the embedding/head implementation pair.
        embed_cls = VocabParallelEmbedding
        head_cls = VocabParallelGPTLMHead1D
        if embed_split_hidden:
            embed_cls = HiddenParallelEmbedding
            head_cls = HiddenParallelGPTLMHead1D
        if first:
            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
        blocks = nn.ModuleList([
            GPTTransformerLayer1D(hidden_size,
                                  num_attention_heads,
                                  act_func=act_func,
                                  mlp_ratio=mlp_ratio,
                                  attention_dropout_prob=attn_drop_rate,
                                  hidden_dropout_prob=drop_rate,
                                  dtype=dtype,
                                  checkpoint=checkpoint,
                                  max_position_embeddings=max_position_embeddings,
                                  layer_norm_epsilon=layer_norm_epsilon,
                                  apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
        ])
        if last:
            # NOTE(review): the LayerNorm is created without ``dtype`` while the
            # other modules receive it -- confirm this is intended.
            norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
|
||
|
|
||
|
|
||
|
class FusedPipelineGPT1D(GenericPipelineGPT):
    """Pipeline stage built from ``FusedGPTTransformerLayer1D`` blocks.

    Args:
        num_layers: number of transformer blocks created for this stage.
        hidden_size: hidden dimension, passed to the embedding, every block,
            the final LayerNorm and the head (as ``embed_dim``).
        num_attention_heads: forwarded to every block.
        vocab_size: forwarded to the embedding and the head.
        embed_drop_rate: forwarded to the embedding.
        act_func: forwarded to every block.
        mlp_ratio: forwarded to every block.
        attn_drop_rate: forwarded to every block as ``attention_dropout_prob``.
        drop_rate: forwarded to every block as ``hidden_dropout_prob``.
        dtype: forwarded to the embedding, every block and the head.
        checkpoint: forwarded to every block.
        max_position_embeddings: forwarded to the embedding and every block.
        layer_norm_epsilon: forwarded to every block and used as ``eps`` of
            the final ``kernel.LayerNorm``.
        apply_post_layer_norm: forwarded to every block.
        first: when True, this stage also creates the input embedding.
        last: when True, this stage also creates the final LayerNorm and head.
        embed_split_hidden: when True, use ``HiddenParallelEmbedding`` /
            ``HiddenParallelGPTLMHead1D`` instead of the ``VocabParallel*``
            classes.
    """

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden: bool = False):
        embedding = None
        norm = None
        head = None
        # Pick the embedding/head implementation pair.
        embed_cls = VocabParallelEmbedding
        head_cls = VocabParallelGPTLMHead1D
        if embed_split_hidden:
            embed_cls = HiddenParallelEmbedding
            head_cls = HiddenParallelGPTLMHead1D
        if first:
            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
        blocks = nn.ModuleList([
            FusedGPTTransformerLayer1D(hidden_size,
                                       num_attention_heads,
                                       act_func=act_func,
                                       mlp_ratio=mlp_ratio,
                                       attention_dropout_prob=attn_drop_rate,
                                       hidden_dropout_prob=drop_rate,
                                       dtype=dtype,
                                       checkpoint=checkpoint,
                                       max_position_embeddings=max_position_embeddings,
                                       layer_norm_epsilon=layer_norm_epsilon,
                                       apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
        ])
        if last:
            norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        """Run this stage and return the resulting hidden states.

        Unlike ``GenericPipelineGPT.forward``, the attention mask is only cast
        to the hidden-state dtype; it is not reshaped or converted into an
        additive bias before being handed to the blocks.
        """
        if self.embedding is not None:
            hidden_states = self.embedding(input_ids=input_ids)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
        for block in self.blocks:
            hidden_states, attention_mask = block(hidden_states, attention_mask)
        if self.norm is not None:
            hidden_states = self.head(self.norm(hidden_states))
        return hidden_states
|
||
|
|
||
|
|
||
|
class PipelineGPTHybrid(GenericPipelineGPT):
    """Pipeline stage built from ``col_gpt.GPTBlock`` blocks.

    The embedding is a ``col_gpt.GPTEmbedding`` (only when ``first``); the
    final norm and head are ``col_nn.LayerNorm`` and ``col_nn.Classifier``
    (only when ``last``).

    ``act_func``, ``apply_post_layer_norm`` and ``embed_split_hidden`` are
    accepted but never read here; the blocks always receive
    ``nn.functional.gelu`` as their activation.
    """

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: int = 4,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden=False):
        # Modules are created in the order embedding -> blocks -> norm -> head.
        if first:
            embedding = col_gpt.GPTEmbedding(hidden_size,
                                             vocab_size,
                                             max_position_embeddings,
                                             dropout=embed_drop_rate,
                                             dtype=dtype)
        else:
            embedding = None

        layers = []
        for _ in range(num_layers):
            layers.append(
                col_gpt.GPTBlock(hidden_size,
                                 num_attention_heads,
                                 mlp_ratio=mlp_ratio,
                                 attention_dropout=attn_drop_rate,
                                 dropout=drop_rate,
                                 dtype=dtype,
                                 checkpoint=checkpoint,
                                 activation=nn.functional.gelu))
        blocks = nn.ModuleList(layers)

        norm = head = None
        if last:
            norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            # A bias-free col_nn.Classifier serves as the LM head (an earlier
            # col_gpt.GPTLMHead variant was left commented out in the source).
            head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)

        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
|
||
|
|
||
|
|
||
|
def _filter_kwargs(func, kwargs):
    """Return the entries of ``kwargs`` whose key is a parameter name of ``func``.

    The relative order of the surviving entries is kept.
    """
    accepted = inspect.signature(func).parameters
    return {name: kwargs[name] for name in kwargs if name in accepted}
|
||
|
|
||
|
|
||
|
def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    """Build the model chunk(s) of ``module_cls`` owned by the current pipeline rank.

    Args:
        module_cls: stage class to instantiate; only the keyword arguments its
            ``__init__`` accepts are passed to it.
        num_layers: total number of transformer layers across all stages.
        num_chunks: forwarded to ``partition_uniform``.
        device: device every chunk is moved to.
        **kwargs: model configuration; ``num_layers``, ``first`` and ``last``
            are overwritten for each chunk.

    Returns:
        The single chunk when exactly one was built, otherwise an
        ``nn.ModuleList`` holding the built chunks.
    """
    logger = get_dist_logger()

    # Fall back to a single-stage layout when pipeline parallelism is not initialized.
    pipeline_size, pipeline_rank = 1, 0
    if gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    rank = gpc.get_global_rank()

    # With more than one stage, a shared-module wrapper is created for the
    # first and last pipeline ranks.
    wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) if pipeline_size > 1 else None

    models = []
    for start, end in partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]:
        is_first, is_last = start == 0, end == num_layers
        kwargs.update(num_layers=end - start, first=is_first, last=is_last)
        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
        chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)

        if wrapper is not None:
            # Register the word embeddings of the first chunk, or else the
            # head of the last chunk.
            if is_first:
                wrapper.register_module(chunk.embedding.word_embeddings)
            elif is_last:
                wrapper.register_module(chunk.head)
        models.append(chunk)

    model = models[0] if len(models) == 1 else nn.ModuleList(models)

    numel = sum(param.numel() for _, param in model.named_parameters(recurse=True))
    # NOTE(review): the size estimate hard-codes 2 bytes per parameter
    # regardless of the parameters' actual dtype -- confirm this is intended.
    logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
    return model
|
||
|
|
||
|
|
||
|
def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
    """Build 1D pipeline stage(s), using ``FusedPipelineGPT1D`` when ``fused`` is truthy and ``PipelineGPT1D`` otherwise."""
    if fused:
        stage_cls = FusedPipelineGPT1D
    else:
        stage_cls = PipelineGPT1D
    return _build_generic_gpt_pipeline_1d(stage_cls, num_layers, num_chunks, device, **kwargs)
|
||
|
|
||
|
|
||
|
def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    """Build pipeline stage(s) of ``PipelineGPTHybrid`` via the generic builder."""
    stage_cls = PipelineGPTHybrid
    return _build_generic_gpt_pipeline_1d(stage_cls, num_layers, num_chunks, device, **kwargs)
|
||
|
|
||
|
|
||
|
def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    """Build 1D pipeline stage(s) for a 12-layer model with hidden size 768 and 12 attention heads."""
    return _build_gpt_pipeline_1d(
        12,
        num_chunks,
        fused=fused,
        hidden_size=768,
        num_attention_heads=12,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
    )
|
||
|
|
||
|
|
||
|
def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    """Build 1D pipeline stage(s) for a 48-layer model with hidden size 1600 and 32 attention heads."""
    return _build_gpt_pipeline_1d(
        48,
        num_chunks,
        fused=fused,
        hidden_size=1600,
        num_attention_heads=32,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
    )
|
||
|
|
||
|
|
||
|
def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    """Build 1D pipeline stage(s) for a 96-layer model with hidden size 12288, 96 attention heads and 2048 positions."""
    return _build_gpt_pipeline_1d(
        96,
        num_chunks,
        fused=fused,
        hidden_size=12288,
        num_attention_heads=96,
        checkpoint=checkpoint,
        max_position_embeddings=2048,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
    )
|
||
|
|
||
|
|
||
|
def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    """Build hybrid pipeline stage(s) for a 48-layer model with hidden size 1600 and 32 attention heads."""
    return _build_gpt_pipeline_hybrid(
        48,
        num_chunks,
        hidden_size=1600,
        num_attention_heads=32,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
    )
|
||
|
|
||
|
|
||
|
def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    """Build hybrid pipeline stage(s) for a 12-layer model with hidden size 768 and 12 attention heads."""
    return _build_gpt_pipeline_hybrid(
        12,
        num_chunks,
        hidden_size=768,
        num_attention_heads=12,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
    )
|
||
|
|
||
|
|
||
|
def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    """Build hybrid pipeline stage(s) for a 96-layer model with hidden size 12288, 96 attention heads and 2048 positions."""
    return _build_gpt_pipeline_hybrid(
        96,
        num_chunks,
        hidden_size=12288,
        num_attention_heads=96,
        checkpoint=checkpoint,
        max_position_embeddings=2048,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
    )
|