refactor moe layer

pull/567/head
Wenwen Qu 2024-01-10 15:39:16 +08:00
parent c423f1159b
commit 7cec7e985f
5 changed files with 54 additions and 60 deletions

View File

@ -143,7 +143,7 @@ model = dict(
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_experts=4,
moe_use_residual=False,
moe_gate_k=2,
moe_type="GShard",
)
# zero1 parallel:
@ -176,6 +176,17 @@ monitor = dict(
),
)
# custom moe impl configs
moe = dict(
top_k=2,
capacity_factor=1.0,
eval_capacity_factor=1.0,
min_capacity=4,
noisy_gate_policy=None,
drop_tokens=True,
use_rts=True,
)
model_type = "INTERNLM_MoE"
# metric_dtype can be "fp32" or other string

View File

@ -305,8 +305,8 @@ def args_sanity_check():
model._add_item("num_experts", 1)
if "moe_use_residual" not in model:
model._add_item("moe_use_residual", False)
if "moe_gate_k" not in model:
model._add_item("moe_gate_k", 2)
if "moe_type" not in model:
model._add_item("moe_type", None)
# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)

View File

@ -53,16 +53,9 @@ class PackedFlashBaseLayer1D(nn.Module):
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
moe_type (str): determine which moe impl will be used, default is GShardMoE
"""
def __init__(
@ -158,6 +151,7 @@ class PackedFlashBaseLayer1D(nn.Module):
self.mlp = MoE(
hidden_size=hidden_size,
num_experts=num_experts,
ep_group=gpc.get_group(ParallelMode.EXPERT),
ep_size=ep_size,
device=device,
dtype=dtype,
@ -292,16 +286,9 @@ class PackedFlashInternLm1D(nn.Module):
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
moe_type (str): determine which moe impl will be used, default is GShardMoE
"""
def __init__(
@ -519,13 +506,6 @@ def build_model_with_moe_cfg(
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1, # pylint: disable=W0613
moe_capacity_factor: float = 1.0, # pylint: disable=W0613
moe_eval_capacity_factor: float = 1.0, # pylint: disable=W0613
moe_min_capacity: int = 4, # pylint: disable=W0613
moe_noisy_gate_policy: str = None, # pylint: disable=W0613
moe_drop_tokens: bool = True, # pylint: disable=W0613
moe_use_rts: bool = True, # pylint: disable=W0613
moe_use_residual: bool = False, # pylint: disable=W0613
moe_type: str = None, # pylint: disable=W0613
):
@ -559,16 +539,9 @@ def build_model_with_moe_cfg(
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
moe_type (str): determine which moe impl will be used, default is GShardMoE
"""
cfg = dict(

View File

@ -36,6 +36,7 @@ class MoE(torch.nn.Module):
self,
hidden_size,
num_experts=1,
ep_group=None,
ep_size=1,
device=None,
dtype=None,
@ -43,27 +44,23 @@ class MoE(torch.nn.Module):
super().__init__()
assert (
num_experts % ep_size == 0
), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
self.ep_size = ep_size
self.num_experts = num_experts
self.num_local_experts = num_experts // self.ep_size
moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))
moe_type = getattr(gpc.config.model, "moe_type", None)
if not hasattr(gpc.config, "moe"):
gpc.config.moe = dict()
if moe_type is None or moe_type == "GShard":
self.moe_layer = GShardMOELayer(
hidden_size,
gpc.get_group(ParallelMode.EXPERT),
ep_size,
num_experts,
device,
dtype,
)
self.moe_layer = moe_impl(
hidden_size=hidden_size,
num_experts=num_experts,
ep_group=ep_group,
ep_size=ep_size,
device=device,
dtype=dtype,
**(gpc.config.moe)
)
# residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
self.use_residual = getattr(gpc.config.model, "moe_use_residual", False)
self.use_residual = gpc.config.model.moe_use_residual
if self.use_residual:
self.residual_mlp = FeedForward(
hidden_size,
@ -77,6 +74,10 @@ class MoE(torch.nn.Module):
# coefficient is used for weighted sum of the output of expert and residual mlp
self.coefficient = torch.nn.Linear(hidden_size, 2)
def get_moe(self, moe_type):
if moe_type is None or moe_type == "GShard":
return GShardMOELayer
def forward(self, hidden_states, used_token=None):
"""MoE forward

View File

@ -385,27 +385,36 @@ class GShardMOELayer(BaseMoELayer):
def __init__(
self,
hidden_size,
num_experts: int,
ep_group,
ep_size: int,
num_experts: int,
top_k: int = 1,
capacity_factor: float = 1.0,
eval_capacity_factor: float = 1.0,
min_capacity: int = 4,
noisy_gate_policy: str = None,
drop_tokens: bool = True,
use_rts: bool = True,
device=None,
dtype=None,
) -> None:
noisy_gate_policy = getattr(gpc.config.model, "noisy_gate_policy", None)
assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
"Unsupported noisy_gate_policy: " + noisy_gate_policy
)
assert (
num_experts % ep_size == 0
), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
super().__init__(
TopKGate(
hidden_size,
num_experts,
topk=getattr(gpc.config.model, "moe_gate_k", 1),
capacity_factor=getattr(gpc.config.model, "moe_capacity_factor", 1.0),
eval_capacity_factor=getattr(gpc.config.model, "moe_eval_capacity_factor", 1.0),
min_capacity=getattr(gpc.config.model, "moe_min_capacity", 4),
noisy_gate_policy=getattr(gpc.config.model, "moe_noisy_gate_policy", None),
drop_tokens=getattr(gpc.config.model, "moe_drop_tokens", True),
use_rts=getattr(gpc.config.model, "moe_use_rts", True),
top_k,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts,
),
torch.nn.ModuleList(
[