mirror of https://github.com/InternLM/InternLM

refactor moe layer

parent c423f1159b
commit 7cec7e985f
@@ -143,7 +143,7 @@ model = dict(
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
     num_experts=4,
     moe_use_residual=False,
-    moe_gate_k=2,
+    moe_type="GShard",
 )

 # zero1 parallel:
@@ -176,6 +176,17 @@ monitor = dict(
     ),
 )

+# custom moe impl configs
+moe = dict(
+    top_k=2,
+    capacity_factor=1.0,
+    eval_capacity_factor=1.0,
+    min_capacity=4,
+    noisy_gate_policy=None,
+    drop_tokens=True,
+    use_rts=True,
+)
+
 model_type = "INTERNLM_MoE"

 # metric_dtype can be "fp32" or other string
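For orientation, the two config hunks above fit together roughly as follows once applied. This is a sketch assembled from the diff, not a verbatim copy of the config file, and the comments are editorial.

# Sketch of the MoE-related config after this refactor (assembled from the hunks above).
model = dict(
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    num_experts=4,  # <=1 means dense, >1 means MoE
    moe_use_residual=False,
    moe_type="GShard",  # selects the MoE implementation; None also falls back to GShard
)

# Implementation-specific hyperparameters now live in their own dict; the MoE wrapper
# forwards them verbatim into the chosen layer via **gpc.config.moe.
moe = dict(
    top_k=2,
    capacity_factor=1.0,
    eval_capacity_factor=1.0,
    min_capacity=4,
    noisy_gate_policy=None,  # other valid options: 'Jitter', 'RSample'
    drop_tokens=True,
    use_rts=True,
)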
@@ -305,8 +305,8 @@ def args_sanity_check():
         model._add_item("num_experts", 1)
     if "moe_use_residual" not in model:
         model._add_item("moe_use_residual", False)
-    if "moe_gate_k" not in model:
-        model._add_item("moe_gate_k", 2)
+    if "moe_type" not in model:
+        model._add_item("moe_type", None)
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
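As a self-contained illustration of what the args_sanity_check change does, the snippet below mirrors the default-filling logic with a plain dict; fill_moe_defaults and setdefault are stand-ins of my own, the real code mutates gpc.config.model through _add_item.

# Hypothetical stand-in for the MoE defaults in args_sanity_check: configs that never
# set moe_type keep working, and None later resolves to the GShard implementation.
def fill_moe_defaults(model: dict) -> dict:
    model.setdefault("num_experts", 1)
    model.setdefault("moe_use_residual", False)
    model.setdefault("moe_type", None)  # replaces the old moe_gate_k default
    return model

print(fill_moe_defaults({"num_experts": 4}))
# -> {'num_experts': 4, 'moe_use_residual': False, 'moe_type': None}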
@@ -53,16 +53,9 @@ class PackedFlashBaseLayer1D(nn.Module):
         norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
-        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
-        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
-        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
-        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
-            infinite capacity).
-        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        moe_type (str): determines which MoE implementation will be used; defaults to GShardMoE.
     """

     def __init__(
@@ -158,6 +151,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             self.mlp = MoE(
                 hidden_size=hidden_size,
                 num_experts=num_experts,
+                ep_group=gpc.get_group(ParallelMode.EXPERT),
                 ep_size=ep_size,
                 device=device,
                 dtype=dtype,
@@ -292,16 +286,9 @@ class PackedFlashInternLm1D(nn.Module):
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
-        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
-        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
-        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
-        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
-            to infinite capacity).
-        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        moe_type (str): determines which MoE implementation will be used; defaults to GShardMoE.
     """

     def __init__(
@@ -519,13 +506,6 @@ def build_model_with_moe_cfg(
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
     num_experts: int = 1,
-    moe_gate_k: int = 1,  # pylint: disable=W0613
-    moe_capacity_factor: float = 1.0,  # pylint: disable=W0613
-    moe_eval_capacity_factor: float = 1.0,  # pylint: disable=W0613
-    moe_min_capacity: int = 4,  # pylint: disable=W0613
-    moe_noisy_gate_policy: str = None,  # pylint: disable=W0613
-    moe_drop_tokens: bool = True,  # pylint: disable=W0613
-    moe_use_rts: bool = True,  # pylint: disable=W0613
     moe_use_residual: bool = False,  # pylint: disable=W0613
     moe_type: str = None,  # pylint: disable=W0613
 ):
@@ -559,16 +539,9 @@ def build_model_with_moe_cfg(
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
-        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
-        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
-        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
-        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
-        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
-            to infinite capacity).
-        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
             (https://arxiv.org/abs/2201.05596) layer.
+        moe_type (str): determines which MoE implementation will be used; defaults to GShardMoE.
     """

     cfg = dict(
@@ -36,6 +36,7 @@ class MoE(torch.nn.Module):
         self,
         hidden_size,
         num_experts=1,
+        ep_group=None,
         ep_size=1,
         device=None,
         dtype=None,
@@ -43,27 +44,23 @@ class MoE(torch.nn.Module):

         super().__init__()

-        assert (
-            num_experts % ep_size == 0
-        ), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
-        self.ep_size = ep_size
-        self.num_experts = num_experts
-        self.num_local_experts = num_experts // self.ep_size
+        moe_impl = self.get_moe(getattr(gpc.config.model, "moe_type", None))

-        moe_type = getattr(gpc.config.model, "moe_type", None)
+        if not hasattr(gpc.config, "moe"):
+            gpc.config.moe = dict()

-        if moe_type is None or moe_type == "GShard":
-            self.moe_layer = GShardMOELayer(
-                hidden_size,
-                gpc.get_group(ParallelMode.EXPERT),
-                ep_size,
-                num_experts,
-                device,
-                dtype,
-            )
+        self.moe_layer = moe_impl(
+            hidden_size=hidden_size,
+            num_experts=num_experts,
+            ep_group=ep_group,
+            ep_size=ep_size,
+            device=device,
+            dtype=dtype,
+            **(gpc.config.moe)
+        )

         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
-        self.use_residual = getattr(gpc.config.model, "moe_use_residual", False)
+        self.use_residual = gpc.config.model.moe_use_residual
         if self.use_residual:
             self.residual_mlp = FeedForward(
                 hidden_size,
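The rewritten MoE.__init__ boils down to two moves: resolve an implementation class from moe_type, then splat the whole gpc.config.moe dict into its constructor. Below is a framework-free sketch of the same pattern; every name in it (ToyGShardLayer, resolve_impl, moe_cfg) is a placeholder of mine, not an InternLM API.

# Toy version of "resolve the implementation by name, then forward a config dict as kwargs".
class ToyGShardLayer:
    def __init__(self, hidden_size, num_experts, top_k=1, capacity_factor=1.0, **_unused):
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor


def resolve_impl(moe_type):  # plays the role of MoE.get_moe
    if moe_type is None or moe_type == "GShard":
        return ToyGShardLayer
    raise NotImplementedError(f"unsupported moe_type: {moe_type}")


moe_cfg = dict(top_k=2, capacity_factor=1.0)  # plays the role of gpc.config.moe
layer = resolve_impl("GShard")(hidden_size=4096, num_experts=4, **moe_cfg)
print(layer.top_k)  # -> 2

Because the hyperparameters arrive as keyword arguments, adding another implementation would only require a new branch in get_moe plus a constructor that accepts whatever keys users put in the moe dict.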
@@ -77,6 +74,10 @@ class MoE(torch.nn.Module):
             # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)

+    def get_moe(self, moe_type):
+        if moe_type is None or moe_type == "GShard":
+            return GShardMOELayer
+
     def forward(self, hidden_states, used_token=None):
         """MoE forward

@@ -385,27 +385,36 @@ class GShardMOELayer(BaseMoELayer):
     def __init__(
         self,
         hidden_size,
+        num_experts: int,
         ep_group,
         ep_size: int,
-        num_experts: int,
+        top_k: int = 1,
+        capacity_factor: float = 1.0,
+        eval_capacity_factor: float = 1.0,
+        min_capacity: int = 4,
+        noisy_gate_policy: str = None,
+        drop_tokens: bool = True,
+        use_rts: bool = True,
        device=None,
        dtype=None,
     ) -> None:
-        noisy_gate_policy = getattr(gpc.config.model, "noisy_gate_policy", None)
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
             "Unsupported noisy_gate_policy: " + noisy_gate_policy
         )
+        assert (
+            num_experts % ep_size == 0
+        ), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
         super().__init__(
             TopKGate(
                 hidden_size,
                 num_experts,
-                topk=getattr(gpc.config.model, "moe_gate_k", 1),
-                capacity_factor=getattr(gpc.config.model, "moe_capacity_factor", 1.0),
-                eval_capacity_factor=getattr(gpc.config.model, "moe_eval_capacity_factor", 1.0),
-                min_capacity=getattr(gpc.config.model, "moe_min_capacity", 4),
-                noisy_gate_policy=getattr(gpc.config.model, "moe_noisy_gate_policy", None),
-                drop_tokens=getattr(gpc.config.model, "moe_drop_tokens", True),
-                use_rts=getattr(gpc.config.model, "moe_use_rts", True),
+                top_k,
+                capacity_factor,
+                eval_capacity_factor,
+                min_capacity,
+                noisy_gate_policy,
+                drop_tokens,
+                use_rts,
             ),
             torch.nn.ModuleList(
                 [
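One detail worth noting in this hunk: num_experts moves ahead of ep_group in the constructor, which is only safe because the call site in MoE.__init__ now passes every argument by keyword. The toy below, with made-up function names, shows why a positional caller would otherwise misbind silently.

# Illustration only: reordering parameters breaks positional callers without raising.
def old_signature(hidden_size, ep_group, ep_size, num_experts):
    return num_experts

def new_signature(hidden_size, num_experts, ep_group, ep_size):
    return num_experts

old_style_args = (4096, "expert_group", 1, 4)  # order used by the pre-refactor call site
print(new_signature(*old_style_args))  # -> "expert_group", silently wrong
print(new_signature(4096, num_experts=4, ep_group="expert_group", ep_size=1))  # -> 4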