InternLM/internlm/model/modeling_llama_moe.py

1256 lines
59 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) InternLM. All rights reserved.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.naive_amp import set_fp32_attr_to_module
from internlm.initialize.initialize_tensor import (
normal_,
scaled_init_method_normal,
scaled_init_method_uniform,
uniform_,
)
from internlm.model.embedding import Embedding1D, RotaryEmbedding
from internlm.model.linear import (
ColumnParallelLinearTorch,
FeedForward,
RewardModelLinear,
RowParallelLinearTorch,
ScaleColumnParallelLinear,
)
from internlm.model.moe import MoE
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER
try:
from flash_attn import flash_attn_varlen_kvpacked_func
from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mha import (
CrossAttention,
FlashCrossAttention,
FlashSelfAttention,
SelfAttention,
_update_kv_cache,
)
from flash_attn.modules.mlp import ParallelFusedMLP
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
pass
# String tag naming this model family (presumably used as a registry/config key — confirm against callers).
MODEL_TYPE = "LLAMA_MoE"
# Module-level logger, created from this file's path.
logger = get_logger(__file__)
# Norm class resolved once at import time by try_import_RMSNorm(); instantiated below
# when a layer is built with norm_type == "rmsnorm".
RMSNorm = try_import_RMSNorm()
class MHA(nn.Module):
    """
    Multi-head self-attention and cross-attention.

    Queries use ``num_heads`` heads while keys/values use ``num_kv_heads`` heads, each of size
    ``embed_dim // num_heads``. The flash-attention classes and functions used here come from
    the optional ``flash_attn`` import at the top of this file.

    Args:
        embed_dim (int): The dimension of hidden state.
        num_heads (int): The number of attention heads.
        num_kv_heads (int): The number of kv attention heads.
        process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
        bias (boolean): Whether the bias is needed for linears. Used for the Q/K/V projections and the
                        output projection. True by default.
        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
        softmax_scale (float): The temperature to use for the softmax attention.
        causal (boolean): Whether to apply causal attention mask. False by default.
        layer_idx (int): The index of current layer. None by default. Required for generation
                        (KV cache lookup).
        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
        rotary_emb_dim (int): The dimension of Rotary Embedding. 0 by default.
        rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
                                    XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
        use_flash_attn (boolean): Whether to use flash attention or not. If False, vanilla attention module
                                    will be used. True by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        rot_embed_HF_impl: rotary embedding hf implementation. False by default.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        bias: bool = True,
        dropout: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        layer_idx: Optional[int] = None,
        rope_base: int = 10000,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        use_flash_attn: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        rot_embed_HF_impl: Optional[bool] = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        self.num_kv_heads = num_kv_heads
        self.kv_dim = self.head_dim * num_kv_heads
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.dtype = dtype
        self.rot_embed_HF_impl = rot_embed_HF_impl
        sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)

        # The rotary embedding only exists when a positive rotary dimension is requested.
        if self.rotary_emb_dim > 0:
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device
            )

        # Separate Q/K/V projections; K and V project to kv_dim (= head_dim * num_kv_heads).
        # All of them follow the `bias` argument.
        self.wq = ColumnParallelLinearTorch(
            embed_dim,
            embed_dim,
            process_group,
            bias=bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.wk = ColumnParallelLinearTorch(
            embed_dim,
            self.kv_dim,
            process_group,
            bias=bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.wv = ColumnParallelLinearTorch(
            embed_dim,
            self.kv_dim,
            process_group,
            bias=bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )

        # Kept as plain attributes because the packed path calls the functional varlen kernel directly.
        self.inner_cross_attn_causal = causal
        self.inner_cross_attn_softmax_scale = softmax_scale
        self.inner_cross_attn_dropout = dropout

        # The output projection follows the same `bias` setting as the Q/K/V projections.
        self.wo = RowParallelLinearTorch(
            embed_dim,
            embed_dim,
            process_group,
            bias=bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

        # need to assign tp attribute so that internlm know it is tensor parallel module
        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for name in ["wo", "wq", "wk", "wv"]:
                for param in getattr(self, name).parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """Dispatch to the packed path when ``indexes`` is given in kwargs, else to the padded path."""
        if kwargs.get("indexes", None) is not None:
            return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
        else:
            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)

    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):  # pylint: disable=W0613
        """
        Attention over padded (non-packed) input, for both training/eval and cached generation.

        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            seqlen: sequence length used to un-flatten x when it is 2-D.
            inference_params: generation state. When given it must expose ``attention_mask``,
                ``sequence_len_offset`` and the KV cache consumed by ``_update_kv_cache``;
                ``window_size`` / ``keep_first`` are optional.

        Returns:
            The output of the ``wo`` projection, in the same layout as x (3-D if seqlen is None, else 2-D).

        Raises:
            NotImplementedError: if ``inference_params`` is given without a non-None ``attention_mask``.
        """
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        if seqlen is None:
            q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
            k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
            v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
        else:
            q = rearrange(q, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim)
            k = rearrange(k, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim)
            v = rearrange(v, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim)
        # Take the batch size from the 4-D query: unpacking x.shape into three values
        # fails for the 2-D (batch * seqlen, hidden_dim) layout.
        bsz = q.shape[0]

        # Move even-indexed features to the front half and odd-indexed to the back half, since the
        # rotary embedding used here works on front/back halves rather than interleaved pairs.
        if not self.rot_embed_HF_impl:
            q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
            k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)

        if inference_params is None:
            if self.rotary_emb_dim > 0:
                q = self.rotary_emb._single_eval_forward(q)
                k = self.rotary_emb._single_eval_forward(k)
            kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
            if self.dtype is torch.float32 and self.use_flash_attn:
                # Inputs are cast to bf16 when the module runs in fp32 with flash attention,
                # and the result is cast back to the module dtype.
                if q.dtype not in [torch.float16, torch.bfloat16]:
                    q = q.to(torch.bfloat16)
                if kv.dtype not in [torch.float16, torch.bfloat16]:
                    kv = kv.to(torch.bfloat16)
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    context = self.inner_cross_attn(q, kv).to(self.dtype)
            else:
                context = self.inner_cross_attn(q, kv)
        else:
            assert self.rotary_emb_dim > 0
            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
                # Per-sample count taken from the last column of the mask; used below as the number of
                # leading positions to skip when assigning rotary positions (presumably left padding).
                empties = inference_params.attention_mask[..., -1].sum(dim=-1)
                moved_q = q.clone()
                moved_k = k.clone()
                if inference_params.sequence_len_offset == 0:
                    # Shift each sample left by its empty count, apply rotary, then shift back,
                    # so the first non-empty token gets position 0.
                    for i in range(len(empties)):
                        if empties[i] != 0:
                            moved_q[i][: -empties[i]] = q[i][empties[i] :]
                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
                    moved_q = self.rotary_emb._single_eval_forward(
                        moved_q, seqlen_offset=inference_params.sequence_len_offset
                    )
                    moved_k = self.rotary_emb._single_eval_forward(
                        moved_k, seqlen_offset=inference_params.sequence_len_offset
                    )
                    for i in range(len(empties)):
                        if empties[i] != 0:
                            q[i][empties[i] :] = moved_q[i][: -empties[i]]
                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
                        else:
                            q[i] = moved_q[i]
                            k[i] = moved_k[i]
                else:
                    # Incremental step: the rotary position is the cache offset minus the empty count.
                    q = q.squeeze(1)
                    k = k.squeeze(1)
                    q = self.rotary_emb._single_forward(
                        q,
                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                        - empties,
                    ).unsqueeze(1)
                    k = self.rotary_emb._single_forward(
                        k,
                        inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                        - empties,
                    ).unsqueeze(1)
            else:
                raise NotImplementedError(
                    "You should make sure you are aware that you are changing the method of generating."
                    "According to your generation function instead of inference/seq_generator_module.py, "
                    "You may implement here for normal running."
                )

            kv = torch.stack([k, v], dim=2)

            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
            if hasattr(inference_params, "window_size") and inference_params.window_size is not None:
                if inference_params.window_size <= inference_params.sequence_len_offset:
                    assert kv.size(1) == 1, "update kv lenth more than 1"
                    # Window is full: keep the first `keep_first` cache entries, shift the newest
                    # entries down by one slot, and write the new kv into the last window slot.
                    inference_params.key_value_memory_dict[self.layer_idx][
                        :, inference_params.keep_first : inference_params.window_size - 1, ...
                    ] = inference_params.key_value_memory_dict[self.layer_idx][
                        :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ...
                    ].clone()
                    inference_params.real_sequence_len_offset = inference_params.sequence_len_offset
                    inference_params.sequence_len_offset = inference_params.window_size - 1
                    try:
                        kv = _update_kv_cache(kv, inference_params, self.layer_idx)
                    finally:
                        # Always restore the real offset, even if the cache update raises.
                        inference_params.sequence_len_offset = inference_params.real_sequence_len_offset
                else:
                    kv = _update_kv_cache(kv, inference_params, self.layer_idx)
            else:
                kv = _update_kv_cache(kv, inference_params, self.layer_idx)

            # When using FP16, there is a high probability of NAN in the KV.
            # Since NAN cannot be removed by multiplying with 0, it needs
            # to be removed manually here.
            kv = torch.where(torch.isnan(kv), 0, kv)

            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
                assert self.use_flash_attn is True
                if inference_params.sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
                    attn_mask = inference_params.attention_mask[:, None, ...]
                    attn_mask = torch.logical_or(
                        torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask
                    )
                    # Inverted last mask row: True marks positions that are kept and packed for flash-attn.
                    attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1)
                    cu_seqlens = torch.concat(
                        [
                            torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device),
                            attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32),
                        ],
                        dim=0,
                    )
                    cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32)
                    max_seqlen_q = attn_mask4flsh.shape[-1]
                    max_seqlen_k = attn_mask4flsh.shape[-1]
                    total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
                    total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view(
                        -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
                    )
                    if self.dtype is torch.float32:
                        if total_q.dtype not in [torch.float16, torch.bfloat16]:
                            total_q = total_q.to(torch.bfloat16)
                        if total_kv.dtype not in [torch.float16, torch.bfloat16]:
                            total_kv = total_kv.to(torch.bfloat16)
                        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                            output = FlashAttnVarlenKVPackedFunc.apply(
                                total_q,
                                total_kv,
                                cu_seqlens,
                                cu_seqlens,
                                max_seqlen_q,
                                max_seqlen_k,
                                0.0,
                                None,
                                True,
                                False,
                            ).to(self.dtype)
                    else:
                        output = FlashAttnVarlenKVPackedFunc.apply(
                            total_q,
                            total_kv,
                            cu_seqlens,
                            cu_seqlens,
                            max_seqlen_q,
                            max_seqlen_k,
                            0.0,
                            None,
                            True,
                            False,
                        )

                    # Scatter the packed outputs back into the padded layout; masked-out positions stay zero.
                    context = torch.zeros_like(q)
                    context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
                else:
                    attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1)
                    if hasattr(inference_params, "window_size") and inference_params.window_size is not None:
                        if inference_params.window_size <= inference_params.sequence_len_offset:
                            attn_mask = torch.concat(
                                [
                                    attn_mask[..., : inference_params.keep_first],
                                    attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :],
                                ],
                                dim=-1,
                            )

                    # Vanilla attention for the incremental step; kv heads are repeated
                    # `expansion` times to match the number of query heads.
                    k, v = torch.chunk(kv, 2, dim=2)
                    k = k.squeeze(2)
                    v = v.squeeze(2)
                    sp = k.shape
                    expansion = q.size(2) // k.size(2)
                    scores = torch.einsum(
                        "blhd,bnhd->bhln",
                        q,
                        k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
                    ) / math.sqrt(q.size(-1))
                    scores = scores.masked_fill(attn_mask, -65000.0)
                    scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
                    context = torch.einsum(
                        "bhmn,bnhd->bmhd",
                        scores,
                        v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
                    )
            else:
                if self.dtype is torch.float32 and self.use_flash_attn:
                    if q.dtype not in [torch.float16, torch.bfloat16]:
                        q = q.to(torch.bfloat16)
                    if kv.dtype not in [torch.float16, torch.bfloat16]:
                        kv = kv.to(torch.bfloat16)
                    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                        context = self.inner_cross_attn(q, kv, causal=True).to(self.dtype)
                else:
                    context = self.inner_cross_attn(q, kv, causal=True)

        if seqlen is None:
            context = rearrange(context, "b s h d -> b s (h d)")
        else:
            context = rearrange(context, "b s h d -> (b s) (h d)")

        out = self.wo(context)
        return out

    def _packed_forward(self, x, inference_params=None, **kwargs):
        """
        Attention over packed (variable-length) input using the flash-attn varlen kernel.

        The ``seqlen`` argument is intentionally absent because it is not used on this path.

        Arguments:
            x: (total_tokens, hidden_dim) — the Q/K/V projections of x are rearranged as
                "t (h d) -> t h d", i.e. all sequences are concatenated along the first dimension.
            inference_params: must be None; generation is not supported on this path.
            kwargs: must contain ``indexes`` (passed to the rotary embedding), ``cu_seqlens`` and
                ``max_seqlen`` (passed to ``flash_attn_varlen_kvpacked_func``).

        Returns:
            The output of the ``wo`` projection applied to the (total_tokens, hidden_dim) context.

        Raises:
            RuntimeError: if ``inference_params`` is not None.
        """
        assert self.use_flash_attn is True
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = rearrange(q, "t (h d) -> t h d", d=self.head_dim)
        k = rearrange(k, "t (h d) -> t h d", d=self.head_dim)
        v = rearrange(v, "t (h d) -> t h d", d=self.head_dim)

        # qkv shift
        # the rotary embedding in flash attention module in performed by separating the front and back parts, while
        # most of others are done by odd-even methods.
        if not self.rot_embed_HF_impl:
            q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
            k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)

        indexes = kwargs.pop("indexes")
        # self.rotary_emb is only created when rotary_emb_dim > 0; guard as _forward does.
        if self.rotary_emb_dim > 0:
            q = self.rotary_emb._single_forward(q, indexes=indexes)
            k = self.rotary_emb._single_forward(k, indexes=indexes)

        if inference_params is None:
            kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1)
            if self.dtype is torch.float32:
                if q.dtype not in [torch.float16, torch.bfloat16]:
                    q = q.to(torch.bfloat16)
                if kv.dtype not in [torch.float16, torch.bfloat16]:
                    kv = kv.to(torch.bfloat16)
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    context = flash_attn_varlen_kvpacked_func(
                        q=q,
                        kv=kv,
                        cu_seqlens_q=kwargs["cu_seqlens"],
                        cu_seqlens_k=kwargs["cu_seqlens"],
                        max_seqlen_q=kwargs["max_seqlen"],
                        max_seqlen_k=kwargs["max_seqlen"],
                        dropout_p=self.inner_cross_attn_dropout,
                        softmax_scale=self.inner_cross_attn_softmax_scale,
                        causal=self.inner_cross_attn_causal,
                    ).to(self.dtype)
            else:
                context = flash_attn_varlen_kvpacked_func(
                    q=q,
                    kv=kv,
                    cu_seqlens_q=kwargs["cu_seqlens"],
                    cu_seqlens_k=kwargs["cu_seqlens"],
                    max_seqlen_q=kwargs["max_seqlen"],
                    max_seqlen_k=kwargs["max_seqlen"],
                    dropout_p=self.inner_cross_attn_dropout,
                    softmax_scale=self.inner_cross_attn_softmax_scale,
                    causal=self.inner_cross_attn_causal,
                )
        else:
            raise RuntimeError("Not support this right now")

        context = rearrange(context, "b h d -> b (h d)")  # recover shape
        out = self.wo(context)
        return out
class PackedFlashLlamaLayer1D(nn.Module):
    """
    1D Packed Flash Llama Layer.

    One transformer block: an MHA attention module followed by either a dense feed-forward
    network (``num_experts <= 1``) or a MoE module (``num_experts > 1``). ``forward`` returns
    a ``(hidden_states, moe_loss)`` pair; ``moe_loss`` is a zero tensor for dense layers.

    Args:
        hidden_size (int): The hidden size of model. 768 by default.
        num_attention_heads (int): The number of attention heads. 12 by default.
        num_kv_attention_heads (int): The number of kv attention heads. 12 by default.
        mlp_ratio (int): The ratio of MLP layers. 4 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
        dtype (torch.dtype): Type of data. torch.float by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
        checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
        layer_idx (int): The index of current layer. 0 by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        apply_post_layer_norm (bool): Whether use post layer norm. False by default.
        fused_dropout_add_ln (bool): Whether use fused dropout add ln. True by default, but this layer
            asserts that it is False, so callers must pass ``fused_dropout_add_ln=False`` explicitly.
        no_bias (bool): Whether remove bias. False by default.
        norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
        adapt_hf (bool): Whether adapt hf. False by default.
        dropout_selective_checkpoint (bool): Whether use dropout selective checkpoint. True by default.
            Only takes effect when ``checkpoint`` is False.
        use_scaled_init (bool): Whether use scaled init. True by default.
        use_swiglu (bool): Whether use swiglu. True by default.
        use_flash_attn (bool): Whether use flash-attn. True by default.
        attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
        attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
        ffn_uplayer_init_std (float): std used to init parameters whose name contains "w1" or "w3" in ffn
            when using glu, otherwise parameters whose name contains "fc1". 0.02 by default,
        ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
        init_type (str): Initialization type. Use uniform or normal. "normal" by default,
        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
        num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
                                          to infinite capacity).
        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
                                          (https://arxiv.org/abs/2201.05596) layer.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        num_kv_attention_heads: int = 12,
        mlp_ratio: int = 4,
        attn_drop_rate: float = 0,
        drop_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        layer_norm_epsilon: float = 1e-6,
        checkpoint: bool = False,
        layer_idx: int = 0,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        apply_post_layer_norm: bool = False,
        fused_dropout_add_ln: bool = True,
        no_bias: bool = False,
        norm_type: str = "rmsnorm",
        adapt_hf: bool = False,
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
        attn_wqkv_init_std: float = 0.02,
        attn_other_init_std: float = 0.02,
        ffn_uplayer_init_std: float = 0.02,
        ffn_other_init_std: float = 0.02,
        init_type: str = "normal",
        rope_base: int = 10000,
        num_experts: int = 1,
        moe_gate_k: int = 1,
        moe_capacity_factor: float = 1.0,
        moe_eval_capacity_factor: float = 1.0,
        moe_min_capacity: int = 4,
        moe_noisy_gate_policy: Optional[str] = None,
        moe_drop_tokens: bool = True,
        moe_use_rts: bool = True,
        moe_use_residual: bool = False,
    ):
        super().__init__()
        self.checkpoint = checkpoint
        # dropout selective checkpoint can only be enabled when checkpoint is disabled.
        self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
        self.layer_idx = layer_idx
        self.use_flash_attn = use_flash_attn
        self.prenorm = not apply_post_layer_norm
        # NOTE: the signature default (True) fails this assertion; callers must pass False explicitly.
        assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.attn_wqkv_init_std = attn_wqkv_init_std
        self.attn_other_init_std = attn_other_init_std
        self.ffn_uplayer_init_std = ffn_uplayer_init_std
        self.ffn_other_init_std = ffn_other_init_std

        head_dim = hidden_size // num_attention_heads
        self.attention = MHA(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            num_kv_heads=num_kv_attention_heads,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            dropout=attn_drop_rate,
            softmax_scale=1 / math.sqrt(head_dim),
            causal=True,
            layer_idx=layer_idx,
            rotary_emb_dim=head_dim,
            rotary_emb_scale_base=0,
            use_flash_attn=use_flash_attn,
            device=device,
            dtype=dtype,
            rot_embed_HF_impl=adapt_hf,
            bias=not no_bias,
            rope_base=rope_base,
        )

        self.dropout1 = nn.Dropout(drop_rate)
        if norm_type == "rmsnorm":
            self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:
            self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

        # Tag both norms via set_fp32_attr_to_module (presumably keeps them in fp32 under mixed precision).
        set_fp32_attr_to_module(self.attention_norm)
        set_fp32_attr_to_module(self.ffn_norm)

        # Only reachable if the assertion above is stripped (python -O).
        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed"
            assert isinstance(self.attention_norm, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)

        sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)
        tensor_parallel = gpc.get_world_size(ParallelMode.TENSOR) > 1

        self.num_experts = num_experts
        self.moe_gate_k = moe_gate_k
        self.moe_capacity_factor = moe_capacity_factor
        self.moe_eval_capacity_factor = moe_eval_capacity_factor
        self.moe_min_capacity = moe_min_capacity
        self.moe_noisy_gate_policy = moe_noisy_gate_policy
        self.moe_drop_tokens = moe_drop_tokens
        self.moe_use_rts = moe_use_rts
        self.moe_use_residual = moe_use_residual
        ep_size = gpc.get_world_size(ParallelMode.EXPERT)

        if num_experts <= 1:  # dense, not MoE
            if use_swiglu:
                self.feed_forward = FeedForward(
                    hidden_size,
                    int(hidden_size * mlp_ratio),
                    out_features=hidden_size,
                    process_group=gpc.get_group(ParallelMode.TENSOR),
                    bias=False,
                    device=device,
                    dtype=dtype,
                )
            else:
                self.feed_forward = ParallelFusedMLP(
                    hidden_size,
                    int(hidden_size * mlp_ratio),
                    out_features=hidden_size,
                    activation="gelu_approx",
                    process_group=gpc.get_group(ParallelMode.TENSOR),
                    bias1=False,
                    bias2=False,
                    sequence_parallel=sequence_parallel,
                    checkpoint_lvl=0,
                    heuristic="auto",
                    device=device,
                    dtype=dtype,
                )
            if tensor_parallel:
                for _, param in self.feed_forward.named_parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)
        else:
            # replace mlp by MoE module. The expert in MoE is a FeedForward module.
            self.feed_forward = MoE(
                hidden_size=hidden_size,
                num_experts=num_experts,
                ep_size=ep_size,
                k=moe_gate_k,
                capacity_factor=moe_capacity_factor,
                eval_capacity_factor=moe_eval_capacity_factor,
                min_capacity=moe_min_capacity,
                noisy_gate_policy=moe_noisy_gate_policy,
                drop_tokens=moe_drop_tokens,
                use_rts=moe_use_rts,
                use_residual=moe_use_residual,
                device=device,
                dtype=dtype,
            )
            # Only the expert parameters are tagged as tensor parallel; the gate is not.
            if tensor_parallel:
                for _, param in self.feed_forward.moe_layer.experts.named_parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)
            set_fp32_attr_to_module(self.feed_forward.moe_layer.gate)

        # Use the value fetched with .get() above so a config without the key does not fail here.
        if sequence_parallel is True:
            for param in self.attention_norm.parameters():
                setattr(param, IS_SEQUENCE_PARALLEL, True)
            for param in self.ffn_norm.parameters():
                setattr(param, IS_SEQUENCE_PARALLEL, True)

        self.dropout2 = nn.Dropout(drop_rate)
        self.use_swiglu = use_swiglu
        self.use_scaled_init = use_scaled_init
        self.residual_in_fp32 = residual_in_fp32  # only make sense when using prenorm
        self.return_residual = False

        if init_type == "normal":
            self.init_func = normal_
            self.scaled_init_func = scaled_init_method_normal
        else:
            self.init_func = uniform_
            self.scaled_init_func = scaled_init_method_uniform

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize attention and feed-forward parameters in place.

        Attention: 1-D parameters (biases) are zeroed; wq/wk/wv use ``attn_wqkv_init_std``; the remaining
        parameters (wo) use ``attn_other_init_std``, with the scaled init when ``use_scaled_init`` is set.
        Feed-forward: names are matched on "w1"/"w2"/"w3" (swiglu) or "fc1" (otherwise) to choose
        between ``ffn_uplayer_init_std``, ``ffn_other_init_std`` and the scaled init.
        """
        with torch.no_grad():
            for name, param in self.attention.named_parameters():
                if param.ndim == 1:
                    param.data.zero_()
                elif "wq" in name or "wk" in name or "wv" in name:
                    self.init_func(std=self.attn_wqkv_init_std)(param.data)
                elif self.use_scaled_init:  # wo
                    self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
                else:
                    self.init_func(std=self.attn_other_init_std)(param.data)

            for name, param in self.feed_forward.named_parameters():
                if self.use_swiglu:
                    if self.use_scaled_init and "w2" in name:
                        self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        self.init_func(
                            std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
                        )(param.data)
                else:
                    if self.use_scaled_init and "fc1" not in name:
                        self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
                            param.data
                        )

    def forward(
        self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None
    ):
        """Run the layer, through activation checkpointing when ``checkpoint`` is set and the module is training.

        Returns:
            The ``(hidden_states, moe_loss)`` pair produced by ``_forward``.
        """
        if self.checkpoint and self.training:
            return activation_checkpoint(
                self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen
            )
        else:
            return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen)

    def _forward(
        self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Attn/MLP(LN(residual)). Must be None in post-norm mode.
            cu_seqlens: forwarded to the attention module (used there by the packed path).
            indexes: forwarded to the attention module; a non-None value selects its packed path.
            inference_params: forwarded to the attention module.
            max_seqlen: forwarded to the attention module (used there by the packed path).

        Returns:
            A ``(hidden_states, moe_loss)`` tuple. ``moe_loss`` is a zero tensor unless the feed-forward
            module is a MoE (``num_experts > 1``), in which case it is the loss returned by that module.
        """
        if self.prenorm:

            def _dropout_and_norm_attn(_residual, _hidden_states):
                _dropped = self.dropout1(_hidden_states)
                _residual = (_dropped + _residual) if _residual is not None else _dropped
                _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
                return _residual, _hidden_states

            if self.dropout_selective_checkpoint:
                residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
            else:
                residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)

            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

            mixer_kwargs = {
                "cu_seqlens": cu_seqlens,
                "max_seqlen": max_seqlen,
                "indexes": indexes,
                "inference_params": inference_params,
            }
            hidden_states = self.attention(hidden_states, **mixer_kwargs)

            if not isinstance(self.feed_forward, nn.Identity):
                if not self.fused_dropout_add_ln:

                    def _dropout_and_norm_ffn(_residual, _hidden_states):
                        _dropped = self.dropout2(_hidden_states)
                        _residual = (_dropped + _residual) if _residual is not None else _dropped
                        # Cast to the norm's own dtype, consistent with the attention pre-norm above.
                        _hidden_states = self.ffn_norm(_residual.to(dtype=self.ffn_norm.weight.dtype))
                        return _residual, _hidden_states

                    if self.dropout_selective_checkpoint:
                        residual, hidden_states = activation_checkpoint(
                            _dropout_and_norm_ffn, False, residual, hidden_states
                        )
                    else:
                        residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)

                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)

                # MLP.
                moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
                if self.num_experts <= 1:  # dense mlp output
                    hidden_states = self.feed_forward(hidden_states)
                else:  # MoE output
                    hidden_states, moe_loss, _ = self.feed_forward(hidden_states)
            else:
                # No feed-forward module: still return a defined (zero) MoE loss.
                moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)

            return hidden_states + residual, moe_loss
        else:
            assert residual is None

            mixer_kwargs = {
                "cu_seqlens": cu_seqlens,
                "max_seqlen": max_seqlen,
                "indexes": indexes,
                "inference_params": inference_params,
            }
            mixer_out = self.attention(hidden_states, **mixer_kwargs)
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
                dtype=self.attention_norm.weight.dtype
            )

            if not isinstance(self.feed_forward, nn.Identity):
                # MLP.
                moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
                if self.num_experts <= 1:  # dense mlp output
                    mlp_out = self.feed_forward(hidden_states)
                else:  # MoE output
                    mlp_out, moe_loss, _ = self.feed_forward(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    # NOTE: should not be here
                    mlp_out, hidden_states = mlp_out
                hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
                    dtype=self.ffn_norm.weight.dtype
                )
            else:
                # No feed-forward module: still return a defined (zero) MoE loss.
                moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)

            return hidden_states, moe_loss
class PackedFlashLlama1D(nn.Module):
"""
1D Packed Flash Llama.
Args:
num_layers (int): The number of layer. 12 by default.
hidden_size (int): The size of hidden state. 768 by default.
num_attention_heads (int): The number of attention head. 12 by default.
num_kv_attention_heads (int): The number of kv attention head. 12 by default.
vocab_size (int): The size of vocabulary. 50304 by default.
mlp_ratio (int): The ratio of MLP layers. 4 by default.
attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
dtype (torch.dtype): The type of data. torch.float by default.
checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number
of layers. 1.0 by default.
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
first (bool): Whether input embedding layer or not. False by default.
last (bool): Whether output embedding layer or not. False by default.
embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention.
True by default.
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
apply_post_layer_norm (bool): Whether use post layer norm. False by default.
no_bias (bool): Whether remove bias. False by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
adapt_hf (bool): Whether adapt hf. False by default.
is_reward (bool): Whether use is_reward. False by default.
dropout_selective_checkpoint (bool): Whether dropout selective checkpoint. True by default.
use_scaled_init (bool): Whether use scaled init. True by default.
use_swiglu (bool): Whether use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
embedding_init_std (float): std used to init embedding weight. 0.02 by default,
attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
otherwise init fc1 weight in ffn. 0.02 by default,
ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
init_type (str): Initialization type. Use uniform or normal. "normal" by default,
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
"""
def __init__(
    self,
    num_layers: int = 12,
    hidden_size: int = 768,
    num_attention_heads: int = 12,
    num_kv_attention_heads: int = 12,
    vocab_size: int = 50304,
    mlp_ratio: int = 4,
    attn_drop_rate: float = 0.0,
    drop_rate: float = 0.0,
    dtype: torch.dtype = torch.float,
    checkpoint: bool = False,
    checkpoint_fraction: float = 1.0,
    layer_norm_epsilon: float = 1e-5,
    first: bool = False,
    last: bool = False,
    embed_split_hidden: bool = False,
    embed_grad_scale: float = 0.1,
    parallel_output: bool = True,
    start_layer_idx: int = 0,
    device: Optional[torch.device] = None,
    apply_post_layer_norm=False,
    no_bias=False,
    residual_in_fp32: bool = False,
    norm_type: str = "rmsnorm",
    adapt_hf: bool = False,
    is_reward: bool = False,
    dropout_selective_checkpoint: bool = True,
    use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
    embedding_init_std: float = 0.02,
    attn_wqkv_init_std: float = 0.02,
    attn_other_init_std: float = 0.02,
    ffn_uplayer_init_std: float = 0.02,
    ffn_other_init_std: float = 0.02,
    out_head_init_std: float = 0.02,
    init_type: str = "normal",
    rope_base: int = 10000,
    num_experts: int = 1,
    moe_gate_k: int = 1,
    moe_capacity_factor: float = 1.0,
    moe_eval_capacity_factor: float = 1.0,
    moe_min_capacity: int = 4,
    moe_noisy_gate_policy: Optional[str] = None,
    moe_drop_tokens: bool = True,
    moe_use_rts: bool = True,
    moe_use_residual: bool = False,
):
    """
    Build one chunk of the model: an optional token embedding (``first``), ``num_layers``
    decoder layers, and an optional final norm plus output head (``last``).

    The individual arguments are described in the class docstring.
    """
    super().__init__()
    self.use_flash_attn = use_flash_attn

    # Activation checkpointing is disabled either explicitly or by a non-positive fraction;
    # in both cases no layer is checkpointed.
    if checkpoint_fraction <= 0:
        checkpoint = False
    if not checkpoint:
        checkpoint_fraction = 0
    checkpoint_layer_num = num_layers * checkpoint_fraction

    # Read the flag once with a default so every use below agrees, even when the key is
    # missing from the parallel config.
    sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)
    is_tensor_parallel = gpc.get_world_size(ParallelMode.TENSOR) > 1

    # Any init_type other than "normal" falls back to the uniform initializer.
    make_initializer = normal_ if init_type == "normal" else uniform_

    head_cls = RewardModelLinear if is_reward else ScaleColumnParallelLinear

    if first:
        if embed_split_hidden:
            self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
        else:
            self.tok_embeddings = ParallelGPT2Embeddings(
                embed_dim=hidden_size,
                vocab_size=vocab_size,
                max_position_embeddings=-1,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                padding_idx=None,
                sequence_parallel=sequence_parallel,
                device=device,
                dtype=dtype,
            )
        init_embedding = make_initializer(std=embedding_init_std)
        for _, param in self.tok_embeddings.named_parameters():
            init_embedding(param)
            if is_tensor_parallel:
                setattr(param, IS_TENSOR_PARALLEL, True)
    self.embed_grad_scale = embed_grad_scale

    self.layers = nn.ModuleList(
        [
            PackedFlashLlamaLayer1D(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_kv_attention_heads=num_kv_attention_heads,
                mlp_ratio=mlp_ratio,
                attn_drop_rate=attn_drop_rate,
                drop_rate=drop_rate,
                dtype=dtype,
                layer_norm_epsilon=layer_norm_epsilon,
                # Only the leading `checkpoint_layer_num` layers of this chunk are checkpointed.
                checkpoint=lid < checkpoint_layer_num,
                layer_idx=lid + start_layer_idx,  # This parameter is used for caching during generation
                residual_in_fp32=residual_in_fp32,
                device=device,
                apply_post_layer_norm=apply_post_layer_norm,
                fused_dropout_add_ln=False,
                no_bias=no_bias,
                norm_type=norm_type,
                dropout_selective_checkpoint=dropout_selective_checkpoint,
                use_scaled_init=use_scaled_init,
                use_swiglu=use_swiglu,
                use_flash_attn=use_flash_attn,
                adapt_hf=adapt_hf,
                attn_wqkv_init_std=attn_wqkv_init_std,
                attn_other_init_std=attn_other_init_std,
                ffn_uplayer_init_std=ffn_uplayer_init_std,
                ffn_other_init_std=ffn_other_init_std,
                init_type=init_type,
                rope_base=rope_base,
                num_experts=num_experts,
                moe_gate_k=moe_gate_k,
                moe_capacity_factor=moe_capacity_factor,
                moe_eval_capacity_factor=moe_eval_capacity_factor,
                moe_min_capacity=moe_min_capacity,
                moe_noisy_gate_policy=moe_noisy_gate_policy,
                moe_drop_tokens=moe_drop_tokens,
                moe_use_rts=moe_use_rts,
                moe_use_residual=moe_use_residual,
            )
            for lid in range(num_layers)
        ]
    )

    if last:
        if not apply_post_layer_norm:
            if norm_type == "rmsnorm":
                self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            else:
                self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            # `self.norm` only exists in this branch, so its parameters are tagged here.
            if sequence_parallel is True:
                for param in self.norm.parameters():
                    setattr(param, IS_SEQUENCE_PARALLEL, True)

        self.output = head_cls(
            in_features=hidden_size,
            out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            bias=False,
            device=device,
            dtype=dtype,
            weight_scale=embed_grad_scale,
        )
        init_head = make_initializer(std=out_head_init_std)
        for _, param in self.output.named_parameters():
            init_head(param)
            if is_tensor_parallel:
                setattr(param, IS_TENSOR_PARALLEL, True)

    self.parallel_output = parallel_output
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
    """
    Run this chunk of the model.

    Args:
        hidden_states: Input activations. Replaced by the embedding of ``input_ids`` when this
            chunk owns ``tok_embeddings``.
        cu_seqlens: Cumulative sequence lengths, either a tensor or a single-element list
            holding one. When given, its leading dimension and that of ``hidden_states``
            are squeezed away.
        input_ids: Token ids, only read when this chunk owns ``tok_embeddings``.
        indexes: Optional single-element sequence holding the position ids of the packed tokens.
        inference_params: Forwarded unchanged to every layer.

    Returns:
        tuple: ``(hidden_states, moe_losses)`` where ``moe_losses`` holds the second value
        returned by each layer, in layer order.
    """
    if hasattr(self, "tok_embeddings"):
        hidden_states = self.tok_embeddings(input_ids)
        if self.embed_grad_scale != 1:
            # Forward value is unchanged; only `embed_grad_scale` of the gradient flows back
            # into the embedding because the remainder is detached.
            hidden_states = (
                self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
            )

    if isinstance(cu_seqlens, list):
        assert len(cu_seqlens) == 1
        cu_seqlens = cu_seqlens[0].to(hidden_states.device)

    if cu_seqlens is not None:
        # If cu_seqlens is passed in, it indicates a packed state: the batch dimension
        # with a size of 1 should be directly squeezed off.
        cu_seqlens = cu_seqlens.squeeze(0)
        hidden_states = hidden_states.squeeze(0)

    if indexes is not None:
        assert len(indexes) == 1
        # The indexes are used to indicate the actual position IDs of each token in the packed input.
        indexes = indexes[0]

    # Longest individual sequence in the pack, derived from consecutive cumulative offsets.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

    moe_losses = []
    for block in self.layers:
        hidden_states, moe_loss = block(
            hidden_states,
            residual=None,
            cu_seqlens=cu_seqlens,
            indexes=indexes,
            inference_params=inference_params,
            max_seqlen=max_seqlen,
        )
        moe_losses.append(moe_loss)

    if hasattr(self, "norm"):
        hidden_states = self.norm(hidden_states.float())
    if hasattr(self, "output"):
        hidden_states = self.output(hidden_states)

    if not self.parallel_output:
        hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)

    return hidden_states, moe_losses
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
    """
    Instantiate the model chunk(s) assigned to the current pipeline rank.

    Args:
        num_layers (int): The number of layer.
        num_chunks (int): The number of partitions in pipeline parallel.
        device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.
        **kwargs: Model options; filtered against ``PackedFlashLlama1D.__init__`` before use.

    Returns:
        The single chunk when this rank owns exactly one, otherwise an ``nn.ModuleList`` of
        the chunks. The result carries ``first_layer`` / ``last_layer`` attributes holding the
        bounds of the last chunk built (both 0 when the rank owns none).
    """
    pp_world_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    all_parts = partition_uniform(num_layers, pp_world_size, num_chunks)
    local_parts = all_parts[pp_rank]
    if gpc.is_rank_for_log():
        logger.info(f"The layer sharding is {all_parts}.")

    # Options that are identical for every chunk on this rank.
    kwargs["checkpoint_fraction"] = 1.0
    kwargs["device"] = device

    chunk_start = chunk_end = 0
    chunks = []
    for chunk_start, chunk_end in local_parts:
        kwargs.update(
            num_layers=chunk_end - chunk_start,
            first=chunk_start == 0,
            # If there is no content in the final layer, assign the last layer.
            last=chunk_end == num_layers and len(all_parts[-1]) != 0,
            start_layer_idx=chunk_start,
        )
        chunks.append(PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device))

    torch.distributed.barrier()

    model = chunks[0] if len(chunks) == 1 else nn.ModuleList(chunks)
    model.first_layer = chunk_start
    model.last_layer = chunk_end
    return model
@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
def build_model_with_moe_cfg(
    num_chunks=1,
    checkpoint=False,
    dtype=torch.float,
    embed_split_hidden=False,
    num_layers=48,
    hidden_size=2048,
    vocab_size=50304,
    embed_grad_scale=1,
    parallel_output=True,
    num_attention_heads=32,
    num_kv_attention_heads=None,
    mlp_ratio=4.0,
    residual_in_fp32=False,
    norm_type="rmsnorm",
    adapt_hf=False,
    drop_rate=0,
    attn_drop_rate=0,
    apply_post_layer_norm=False,  # pylint: disable=W0613
    no_bias=False,
    deepnorm=False,  # pylint: disable=W0613
    layer_norm_epsilon=1e-5,
    is_reward=False,
    dropout_selective_checkpoint=True,
    use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
    embedding_init_std: float = 0.02,
    attn_wqkv_init_std: float = 0.02,
    attn_other_init_std: float = 0.02,
    ffn_uplayer_init_std: float = 0.02,
    ffn_other_init_std: float = 0.02,
    out_head_init_std: float = 0.02,
    init_type: str = "normal",
    rope_base: int = 10000,
    num_experts: int = 1,
    moe_gate_k: int = 1,
    moe_capacity_factor: float = 1.0,
    moe_eval_capacity_factor: float = 1.0,
    moe_min_capacity: int = 4,
    moe_noisy_gate_policy: str = None,
    moe_drop_tokens: bool = True,
    moe_use_rts: bool = True,
    moe_use_residual: bool = False,
):
    """
    Collect the model options into a config and build the model with it.

    Args:
        num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
        checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
            False by default.
        num_layers (int): The number of layer. 48 by default.
        hidden_size (int): The size of hidden state. 2048 by default.
        vocab_size (int): The size of vocabulary. 50304 by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 1 by default.
        parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
        num_attention_heads (int): The number of attention head. 32 by default.
        num_kv_attention_heads (int, optional): The number of key/value attention heads. When None (or any
            falsy value) ``num_attention_heads`` is used instead. None by default.
        mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
            because this parameter requires inconsistent data types to be passed between pipelines,
            which requires significant modifications to internlm.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        adapt_hf (bool): Forwarded to the model as ``adapt_hf``. False by default.
        drop_rate (float): The dropout rate of input hidden state. 0 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
        no_bias (bool): Forwarded to the model as ``no_bias``. False by default.
        deepnorm (bool): Accepted but not used by this function. False by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        is_reward (bool): Whether to use reward model. False by default.
        dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
        use_scaled_init (bool): Whether to use scaled init. True by default.
        use_swiglu (bool): Whether to use swiglu. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
        embedding_init_std (float): std used to init embedding weight. 0.02 by default.
        attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default.
        attn_other_init_std (float): std used to init attn_other weight. 0.02 by default.
        ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu,
            otherwise init fc1 weight in ffn. 0.02 by default.
        ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default.
        out_head_init_std (float): std used to init output lmhead weight. 0.02 by default.
        init_type (str): Initialization type. Use uniform or normal. "normal" by default.
        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
        num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
            to infinite capacity).
        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
            (https://arxiv.org/abs/2201.05596) layer.

    Returns:
        Whatever ``_build_generic_model_1d`` returns for this configuration.
    """
    cfg = {
        "hidden_size": hidden_size,
        "num_attention_heads": num_attention_heads,
        # Fall back to the query head count when no separate KV head count is given.
        "num_kv_attention_heads": num_kv_attention_heads or num_attention_heads,
        "checkpoint": checkpoint,
        "dtype": dtype,
        "embed_split_hidden": embed_split_hidden,
        "vocab_size": vocab_size,
        "embed_grad_scale": embed_grad_scale,
        "parallel_output": parallel_output,
        "mlp_ratio": mlp_ratio,
        "apply_post_layer_norm": apply_post_layer_norm,
        "no_bias": no_bias,
        "residual_in_fp32": residual_in_fp32,
        "norm_type": norm_type,
        "adapt_hf": adapt_hf,
        "drop_rate": drop_rate,
        "attn_drop_rate": attn_drop_rate,
        "layer_norm_epsilon": layer_norm_epsilon,
        "is_reward": is_reward,
        "dropout_selective_checkpoint": dropout_selective_checkpoint,
        "use_scaled_init": use_scaled_init,
        "use_swiglu": use_swiglu,
        "use_flash_attn": use_flash_attn,
        "embedding_init_std": embedding_init_std,
        "attn_wqkv_init_std": attn_wqkv_init_std,
        "attn_other_init_std": attn_other_init_std,
        "ffn_uplayer_init_std": ffn_uplayer_init_std,
        "ffn_other_init_std": ffn_other_init_std,
        "out_head_init_std": out_head_init_std,
        "init_type": init_type,
        "rope_base": rope_base,
        "num_experts": num_experts,
        "moe_gate_k": moe_gate_k,
        "moe_capacity_factor": moe_capacity_factor,
        "moe_eval_capacity_factor": moe_eval_capacity_factor,
        "moe_min_capacity": moe_min_capacity,
        "moe_noisy_gate_policy": moe_noisy_gate_policy,
        "moe_drop_tokens": moe_drop_tokens,
        "moe_use_rts": moe_use_rts,
        "moe_use_residual": moe_use_residual,
    }

    return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)