mirror of https://github.com/InternLM/InternLM
Merge pull request #1 from blankde/feature_add_moe_zl
add residual and other moe features
pull/375/head
commit 4a5cf5d1df
@@ -52,6 +52,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
         use_flash_attn (bool): Whether use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
+        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
     """

     def __init__(
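Note (editorial, not part of the diff): all eight moe_* arguments above have defaults, so dense configurations are unaffected. A minimal sketch of the new options with the documented defaults, purely for reference:

# Illustration only: the MoE-related keyword arguments introduced by this change,
# with the defaults documented in the docstring above.
moe_kwargs = dict(
    num_experts=1,                 # <=1 dense, >1 MoE
    moe_gate_k=1,                  # top-k gating, k=1 or k=2
    moe_capacity_factor=1.0,       # expert capacity at training time
    moe_eval_capacity_factor=1.0,  # expert capacity at eval time
    moe_min_capacity=4,            # minimum capacity per expert
    moe_noisy_gate_policy=None,    # 'Jitter' or 'RSample' enables noisy gating
    moe_drop_tokens=True,          # False is equivalent to infinite capacity
    moe_use_rts=True,              # Random Token Selection
    moe_use_residual=False,        # Residual MoE (https://arxiv.org/abs/2201.05596)
)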
@@ -73,6 +81,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: int = 1,
+        moe_gate_k: int = 1,
+        moe_capacity_factor: float = 1.0,
+        moe_eval_capacity_factor: float = 1.0,
+        moe_min_capacity: int = 4,
+        moe_noisy_gate_policy: str = None,
+        moe_drop_tokens: bool = True,
+        moe_use_rts: bool = True,
+        moe_use_residual: bool = False,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -107,6 +123,14 @@ class PackedFlashBaseLayer1D(nn.Module):

         # TODO: replace num_experts and epsize with function parameter
         self.num_experts = num_experts
+        self.moe_gate_k = moe_gate_k
+        self.moe_capacity_factor = moe_capacity_factor
+        self.moe_eval_capacity_factor = moe_eval_capacity_factor
+        self.moe_min_capacity = moe_min_capacity
+        self.moe_noisy_gate_policy = moe_noisy_gate_policy
+        self.moe_drop_tokens = moe_drop_tokens
+        self.moe_use_rts = moe_use_rts
+        self.moe_use_residual = moe_use_residual
         ep_size = gpc.get_world_size(ParallelMode.EXPERT)
         if num_experts <= 1:  # dense, not MoE
             if use_swiglu:
@@ -135,7 +159,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                     dtype=dtype,
                 )
         else:
-            expert = torch.nn.ModuleList(
+            experts = torch.nn.ModuleList(
                 [
                     FeedForward(
                         hidden_size,
@@ -149,9 +173,34 @@ class PackedFlashBaseLayer1D(nn.Module):
                     for i in range(num_experts // ep_size)
                 ]
             )
-            # TODO: test moe for now, need more parameter such as: capacity_factor,
-            # eval_capacity_factor, min_capacity, drop_tokens
-            self.mlp = MoE(hidden_size=hidden_size, expert=expert, ep_size=ep_size, num_experts=num_experts, k=1)
+            if moe_use_residual:
+                residual_mlp = FeedForward(
+                    hidden_size,
+                    int(hidden_size * gpc.config.model.mlp_ratio),
+                    out_features=hidden_size,
+                    process_group=gpc.get_group(ParallelMode.TENSOR),
+                    bias=False,
+                    device=torch.device("cuda"),
+                    dtype=torch.float,
+                )
+
+            self.mlp = MoE(
+                hidden_size=hidden_size,
+                experts=experts,
+                num_experts=num_experts,
+                ep_size=ep_size,
+                k=moe_gate_k,
+                capacity_factor=moe_capacity_factor,
+                eval_capacity_factor=moe_eval_capacity_factor,
+                min_capacity=moe_min_capacity,
+                noisy_gate_policy=moe_noisy_gate_policy,
+                drop_tokens=moe_drop_tokens,
+                use_rts=moe_use_rts,
+                use_residual=moe_use_residual,
+                residual_mlp=residual_mlp if moe_use_residual else None,
+            )

         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
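Note (editorial, not part of the diff): the expert list is sized per expert-parallel rank, so each rank instantiates num_experts // ep_size local FeedForward experts while the gate still routes over all num_experts logical experts. A minimal sketch of that partitioning, with illustrative numbers:

# Illustration only: expert partitioning across the expert-parallel group.
num_experts = 8   # logical experts seen by the gate
ep_size = 4       # expert-parallel world size, i.e. gpc.get_world_size(ParallelMode.EXPERT)
num_local_experts = num_experts // ep_size  # 2 FeedForward modules built on this rank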
@@ -278,7 +327,14 @@ class PackedFlashInternLm1D(nn.Module):
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
+        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
     """

     def __init__(
@@ -309,6 +365,14 @@ class PackedFlashInternLm1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: bool = 1,
+        moe_gate_k: int = 1,
+        moe_capacity_factor: float = 1.0,
+        moe_eval_capacity_factor: float = 1.0,
+        moe_min_capacity: int = 4,
+        moe_noisy_gate_policy: str = None,
+        moe_drop_tokens: bool = True,
+        moe_use_rts: bool = True,
+        moe_use_residual: bool = False,
     ):
         super().__init__()

@@ -361,6 +425,14 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     num_experts=num_experts,
+                    moe_gate_k=moe_gate_k,
+                    moe_capacity_factor=moe_capacity_factor,
+                    moe_eval_capacity_factor=moe_eval_capacity_factor,
+                    moe_min_capacity=moe_min_capacity,
+                    moe_noisy_gate_policy=moe_noisy_gate_policy,
+                    moe_drop_tokens=moe_drop_tokens,
+                    moe_use_rts=moe_use_rts,
+                    moe_use_residual=moe_use_residual,
                 )
                 for lid in range(num_layers)
             ]
@@ -499,6 +571,14 @@ def build_model_with_cfg(
     use_flash_attn: bool = True,
     sequence_parallel: bool = False,  # pylint: disable=W0613
     num_experts: int = 1,
+    moe_gate_k: int = 1,
+    moe_capacity_factor: float = 1.0,
+    moe_eval_capacity_factor: float = 1.0,
+    moe_min_capacity: int = 4,
+    moe_noisy_gate_policy: str = None,
+    moe_drop_tokens: bool = True,
+    moe_use_rts: bool = True,
+    moe_use_residual: bool = False,
 ):
     """
     Builde model with config
@@ -530,7 +610,14 @@ def build_model_with_cfg(
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
         num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
+        moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
     """

     cfg = dict(
@@ -555,6 +642,14 @@ def build_model_with_cfg(
         use_flash_attn=use_flash_attn,
         sequence_parallel=sequence_parallel,
         num_experts=num_experts,
+        moe_gate_k=moe_gate_k,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_eval_capacity_factor=moe_eval_capacity_factor,
+        moe_min_capacity=moe_min_capacity,
+        moe_noisy_gate_policy=moe_noisy_gate_policy,
+        moe_drop_tokens=moe_drop_tokens,
+        moe_use_rts=moe_use_rts,
+        moe_use_residual=moe_use_residual,
     )

     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
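Note (editorial, not part of the diff): because every moe_* argument defaults to the dense-equivalent value, existing configs keep building dense models. Assuming the usual model section of a training config is forwarded to build_model_with_cfg as keyword arguments, enabling MoE might look like the following sketch (field names come from this diff, values are illustrative):

# Illustration only, assuming the training config's model section is passed
# through to build_model_with_cfg as keyword arguments.
model = dict(
    num_experts=4,            # switch from dense to a 4-expert MoE
    moe_gate_k=2,             # top-2 gating
    moe_capacity_factor=1.25,
    moe_use_residual=True,    # add the dense residual branch
    # remaining moe_* options keep the defaults documented above
)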
@@ -58,7 +58,7 @@ class MoE(torch.nn.Module):
     def __init__(
         self,
         hidden_size,
-        expert,
+        experts,
         num_experts=1,
         ep_size=1,
         k=1,
@@ -69,6 +69,8 @@ class MoE(torch.nn.Module):
         drop_tokens: bool = True,
         use_rts: bool = True,
         using_default_moe: bool = True,
+        use_residual=True,
+        residual_mlp=None
     ):

         super().__init__()
@@ -89,7 +91,7 @@ class MoE(torch.nn.Module):
                 "Unsupported noisy_gate_policy: " + noisy_gate_policy
             )

-        experts = Experts(expert, self.num_local_experts)
+        experts = Experts(experts, self.num_local_experts)

         if using_default_moe:
             self.moe_layer = MOELayer(
@@ -110,6 +112,12 @@ class MoE(torch.nn.Module):
                 self.num_local_experts,
             )

+        self.use_residual = use_residual
+        if use_residual:
+            self.residual_mlp = residual_mlp
+            # coefficient is used for weighted sum of the output of expert and mlp
+            self.coefficient = torch.nn.Linear(hidden_size, 2)
+
     def forward(self, hidden_states, used_token=None):
         """MoE forward

@@ -127,5 +135,12 @@ class MoE(torch.nn.Module):
             * exp_counts (int): expert count
         """
         output = self.moe_layer(hidden_states, used_token)
+        if self.use_residual:
+            # Residual MoE
+            output_mlp = self.residual_mlp(hidden_states)
+            if type(output_mlp) is tuple:
+                output_mlp = output_mlp[0]  # Ignore the bias term for now
+            coef = self.coefficient(hidden_states)
+            coef = torch.nn.functional.softmax(coef, dim=-1)
+            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
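Note (editorial, not part of the diff): the residual path blends the MoE output with a dense MLP branch through a learned per-token 2-way softmax, following the Residual MoE of https://arxiv.org/abs/2201.05596. A self-contained sketch of just that blend:

# Illustration only: the residual combination in isolation.
import torch

hidden_size, tokens = 16, 5
coefficient = torch.nn.Linear(hidden_size, 2)     # mirrors self.coefficient
hidden_states = torch.randn(tokens, hidden_size)
moe_out = torch.randn(tokens, hidden_size)        # stand-in for self.moe_layer(...)
mlp_out = torch.randn(tokens, hidden_size)        # stand-in for self.residual_mlp(...)

coef = torch.nn.functional.softmax(coefficient(hidden_states), dim=-1)
output = moe_out * coef[..., 0:1] + mlp_out * coef[..., 1:]  # per-token weighted sum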