from typing import Callable, List, Union

from torch import dtype, nn

from colossalai import nn as col_nn
from colossalai.registry import LAYERS, MODELS
from colossalai.nn.layer import MoeModule
from colossalai.context import MOE_CONTEXT
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule, divide
from model_zoo.gpt.gpt import GPTEmbedding, GPTSelfAttention, GPTMLP, GPTBlock, GPTLMHead
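
# GPT variants in which every other transformer layer replaces the dense MLP with a
# mixture-of-experts (MoE) feed-forward. The prmoe_* factories below follow the PR-MoE
# layout: top-1 routing, pyramid-style expert counts, and a residual dense expert
# enabled via `use_residual=True`.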


@LAYERS.register_module
class MOEGPTBlock(CheckpointModule):

    def __init__(self,
                 num_experts: int,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float,
                 activation: Callable,
                 capacity_factor_train: float = 1.0,
                 capacity_factor_eval: float = 1.0,
                 use_residual: bool = False,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 layernorm_epsilon: float = 1e-5,
                 dtype: dtype = None,
                 bias: bool = True,
                 apply_post_layernorm: bool = False,
                 fuse_scale_mask_softmax: bool = False,
                 checkpoint: bool = False):
        super().__init__(checkpoint)
        self.apply_post_layernorm = apply_post_layernorm
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.attn = GPTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                                     dtype=dtype)
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        mlp_factory_dict = dict(dim=dim,
                                mlp_ratio=mlp_ratio,
                                activation=activation,
                                dropout=dropout,
                                dtype=dtype,
                                bias=bias)

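        # In place of the dense GPT feed-forward, each token is routed through a
        # mixture-of-experts layer: top-1 gating over `num_experts` GPTMLP experts with
        # jitter noise, plus an optional residual dense expert when `use_residual=True`.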
        self.mlp = MoeModule(dim_model=dim,
                             num_experts=num_experts,
                             top_k=1,
                             capacity_factor_train=capacity_factor_train,
                             capacity_factor_eval=capacity_factor_eval,
                             noisy_policy='Jitter',
                             use_residual=use_residual,
                             expert_cls=GPTMLP,
                             **mlp_factory_dict)

    def _forward(self, x, attention_mask=None):
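        # `apply_post_layernorm` picks where the residual branch is taken: before the
        # LayerNorm (pre-LN, the default) or after it (post-LN).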
        if not self.apply_post_layernorm:
            residual = x
        x = self.norm1(x)
        if self.apply_post_layernorm:
            residual = x
        x = residual + self.attn(x, attention_mask)

        if not self.apply_post_layernorm:
            residual = x
        x = self.norm2(x)
        if self.apply_post_layernorm:
            residual = x
        x = residual + self.mlp(x)

        return x, attention_mask


@MODELS.register_module
class MOEGPT(nn.Module):

    def __init__(self,
                 num_experts: Union[int, List[int]],
                 use_residual: bool = False,
                 capacity_factor_train: float = 1.0,
                 capacity_factor_eval: float = 1.0,
                 vocab_size: int = 50304,
                 max_position_embeddings: int = 1024,
                 dim: int = 768,
                 num_heads: int = 12,
                 depth: int = 12,
                 mlp_ratio: float = 4.0,
                 dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 layernorm_epsilon: float = 1e-5,
                 activation: Callable = nn.functional.gelu,
                 padding_idx: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 apply_post_layernorm: bool = False,
                 fuse_scale_mask_softmax: bool = False,
                 checkpoint: bool = False) -> None:
        super().__init__()

        half_depth = divide(depth, 2)
        if isinstance(num_experts, list):
            assert len(num_experts) == half_depth, \
                "The length of num_experts should equal the number of MoE layers"
            num_experts_list = num_experts
        else:
            num_experts_list = [num_experts] * half_depth

        self.embed = GPTEmbedding(embedding_dim=dim,
                                  vocab_size=vocab_size,
                                  max_position_embeddings=max_position_embeddings,
                                  padding_idx=padding_idx,
                                  dropout=embedding_dropout,
                                  dtype=dtype)

        block_list = []
        block_factory_dict = dict(dim=dim,
                                  num_heads=num_heads,
                                  mlp_ratio=mlp_ratio,
                                  activation=activation,
                                  attention_dropout=attention_dropout,
                                  dropout=dropout,
                                  layernorm_epsilon=layernorm_epsilon,
                                  dtype=dtype,
                                  bias=bias,
                                  apply_post_layernorm=apply_post_layernorm,
                                  fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                                  checkpoint=checkpoint)

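        # Layers alternate: even indices get a dense GPTBlock, odd indices get an
        # MOEGPTBlock, so a `depth`-layer model holds `depth // 2` MoE layers.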
        for i in range(depth):
            if i % 2 == 0:
                block_module = GPTBlock(**block_factory_dict)
            else:
                num_experts = num_experts_list[i // 2]
                block_module = MOEGPTBlock(num_experts=num_experts,
                                           capacity_factor_train=capacity_factor_train,
                                           capacity_factor_eval=capacity_factor_eval,
                                           use_residual=use_residual,
                                           **block_factory_dict)

            block_list.append(block_module)

        self.blocks = nn.ModuleList(block_list)

        self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        self.head = GPTLMHead(dim=dim,
                              vocab_size=vocab_size,
                              word_embeeding_weight=self.embed.word_embedding_weight,
                              dtype=dtype)

    def forward(self, input_ids, attention_mask=None):
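        # Clear the auxiliary load-balancing loss accumulated by the MoE layers during
        # the previous forward pass before this one starts.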
        MOE_CONTEXT.reset_loss()
        x = self.embed(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length],
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # Adapted from huggingface.
        if attention_mask is not None:
            batch_size = input_ids.shape[0]
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = col_nn.partition_batch(attention_mask)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=x.dtype)    # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        for block in self.blocks:
            x, attention_mask = block(x, attention_mask)

        x = self.head(self.norm(x))

        return x


def _create_moegpt_model(**model_kwargs):
    model = MOEGPT(**model_kwargs)
    return model


def _prmoe_check_sanity(kwargs_dict):
    logger = get_dist_logger()
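    # Pop 'use_residual' so that a user-supplied value does not clash with the explicit
    # use_residual=True passed by the PR-MoE factories below; warn if it was not set.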
    if not kwargs_dict.pop('use_residual', False):
        logger.warning(
            "PR-MoE requires 'use_residual' to be True; it will be forced to True.",
            ranks=[0])


@MODELS.register_module
def prmoe_4b(**kwargs):
    _prmoe_check_sanity(kwargs)
    model_kwargs = dict(num_experts=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 64, 64],
                        use_residual=True,
                        dim=1024,
                        depth=24,
                        num_heads=16,
                        **kwargs)
    return _create_moegpt_model(**model_kwargs)


@MODELS.register_module
def prmoe_31b(**kwargs):
    _prmoe_check_sanity(kwargs)
    model_kwargs = dict(num_experts=[64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 128, 128],
                        use_residual=True,
                        dim=2048,
                        depth=24,
                        num_heads=16,
                        **kwargs)
    return _create_moegpt_model(**model_kwargs)


@MODELS.register_module
def prmoe_51b(**kwargs):
    _prmoe_check_sanity(kwargs)
    model_kwargs = dict(num_experts=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 64, 64, 64, 64],
                        use_residual=True,
                        dim=3072,
                        depth=32,
                        num_heads=24,
                        **kwargs)
    return _create_moegpt_model(**model_kwargs)
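

# A minimal usage sketch (illustrative, not part of the original file). It assumes the
# distributed and MoE contexts have already been initialised the way the ColossalAI MoE
# examples do (e.g. via colossalai.launch_from_torch), since the MoE layers rely on them:
#
#     import torch
#
#     model = prmoe_4b(checkpoint=True)
#     input_ids = torch.randint(0, 50304, (1, 1024))
#     attention_mask = torch.ones_like(input_ids)
#     logits = model(input_ids, attention_mask)    # [1, 1024, vocab_size]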