mirror of https://github.com/InternLM/InternLM

commit 5f2e082b21 (parent 4a5cf5d1df): reformat code
@@ -57,9 +57,11 @@ class PackedFlashBaseLayer1D(nn.Module):
         moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
         moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
+            infinite capacity).
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
-        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
     """

     def __init__(
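As a reading aid for the capacity knobs documented above: a minimal sketch of how a capacity_factor/min_capacity budget is typically derived in top-k gating. The helper name and example numbers are illustrative, not code from this repo.

import math

def expert_capacity(num_tokens, num_experts, capacity_factor=1.0, min_capacity=4):
    # Per-expert token budget under the given capacity factor,
    # floored at min_capacity regardless of the factor.
    capacity = math.ceil(num_tokens / num_experts * capacity_factor)
    return max(capacity, min_capacity)

# With drop_tokens=True, tokens routed past this budget are dropped;
# drop_tokens=False behaves like an infinite capacity (nothing is dropped).
print(expert_capacity(4096, 8))  # 512
print(expert_capacity(16, 8))    # 4, the min_capacity floor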
@@ -184,7 +186,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 device=torch.device("cuda"),
                 dtype=torch.float,
             )
-
+
             self.mlp = MoE(
                 hidden_size=hidden_size,
                 experts=experts,
@@ -332,9 +334,11 @@ class PackedFlashInternLm1D(nn.Module):
         moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
         moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
+            to infinite capacity).
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
-        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
     """

     def __init__(
@@ -517,7 +521,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
     all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
     parts = all_parts[pipeline_rank]
     if gpc.is_rank_for_log():
-        logger.info(f"The layer sharding is {all_parts}.")
+        logger.info(f"The layer sharding is {all_parts}.")  # pylint: disable=W1203

     models = []
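The sharding logged above comes from partition_uniform, which splits the layer stack into near-equal contiguous ranges, one per pipeline rank. A minimal sketch of that idea; illustrative only, since the repo's partition_uniform also supports num_chunks > 1 for interleaved pipeline schedules.

def uniform_parts(num_layers, pipeline_size):
    # Split num_layers into pipeline_size contiguous, near-equal ranges;
    # the first (num_layers % pipeline_size) ranks get one extra layer.
    base, rem = divmod(num_layers, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        size = base + (1 if rank < rem else 0)
        parts.append((start, start + size))
        start += size
    return parts

print(uniform_parts(32, 4))  # [(0, 8), (8, 16), (16, 24), (24, 32)]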
@@ -578,7 +582,7 @@ def build_model_with_cfg(
     moe_noisy_gate_policy: str = None,
     moe_drop_tokens: bool = True,
     moe_use_rts: bool = True,
-    moe_use_residual: bool = False,
+    moe_use_residual: bool = True,
 ):
     """
     Build model with config.
@@ -615,9 +619,11 @@ def build_model_with_cfg(
         moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
         moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
+            to infinite capacity).
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
-        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
     """

     cfg = dict(
@@ -50,9 +50,13 @@ class MoE(torch.nn.Module):
         min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'
             or 'None'.
         using_default_moe (bool, optional): default=True, whether to use the default MoE layer.
+        drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
+            infinite capacity).
         use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
         residual_mlp (torch.nn.Module, optional): default=None, the torch module that defines the residual MLP.
     """

     def __init__(
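For the noisy_gate_policy options documented above, a sketch of how 'Jitter' and 'RSample' are commonly applied in the DeepSpeed-style gate this docstring mirrors. The function name and the epsilon width are assumptions.

import torch

def apply_noisy_gate_policy(x, policy=None, epsilon=1e-2):
    # Sketch only: 'Jitter' perturbs the gate inputs, 'RSample' the logits.
    if policy == "Jitter":
        # Multiply gate inputs by uniform noise in [1 - eps, 1 + eps].
        return x * torch.empty_like(x).uniform_(1.0 - epsilon, 1.0 + epsilon)
    if policy == "RSample":
        # Add Gumbel noise to the logits before top-k expert selection
        # (drawn via rsample, hence the policy name).
        gumbel = torch.distributions.Gumbel(torch.tensor(0.0), torch.tensor(1.0))
        return x + gumbel.rsample(x.shape)
    return x  # None: deterministic gating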
@@ -70,7 +74,7 @@ class MoE(torch.nn.Module):
         use_rts: bool = True,
         using_default_moe: bool = True,
         use_residual=True,
-        residual_mlp=None
+        residual_mlp=None,
     ):

         super().__init__()
@@ -82,7 +86,7 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size

-        logger.info(
+        logger.info(  # pylint: disable=W1203
             f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
             f"{self.num_local_experts} | expert_parallel_size: {self.ep_size}"
         )
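pylint's W1203 (logging-fstring-interpolation) fires here because an f-string is formatted eagerly, even when the log record ends up filtered out. The commit keeps the f-string and suppresses the warning; the alternative would be lazy %-style arguments, as in this self-contained sketch with example values:

import logging

logger = logging.getLogger(__name__)
num_experts, num_local_experts, ep_size = 8, 2, 4  # example values

# %s placeholders are only interpolated if the record is actually emitted.
logger.info(
    "Creating MoE layer with num_experts: %s | num_local_experts: %s | expert_parallel_size: %s",
    num_experts,
    num_local_experts,
    ep_size,
)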
@@ -136,11 +140,11 @@ class MoE(torch.nn.Module):
         """
         output = self.moe_layer(hidden_states, used_token)
         if self.use_residual:
-            # Residual MoE
-            output_mlp = self.residual_mlp(hidden_states)
-            if type(output_mlp) is tuple:
-                output_mlp = output_mlp[0]  # Ignore the bias term for now
-            coef = self.coefficient(hidden_states)
-            coef = torch.nn.functional.softmax(coef, dim=-1)
-            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
+            # Residual MoE
+            output_mlp = self.residual_mlp(hidden_states)
+            if isinstance(output_mlp, tuple):
+                output_mlp = output_mlp[0]  # Ignore the bias term for now
+            coef = self.coefficient(hidden_states)
+            coef = torch.nn.functional.softmax(coef, dim=-1)
+            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
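The residual branch above is the learned two-way mixing from the Residual MoE paper linked in the docstrings: self.coefficient maps each token to two weights, softmax makes them a convex pair, and the MoE and dense outputs are blended per token. A self-contained sketch of just that blend; the shapes and stand-in tensors are illustrative.

import torch

hidden = 8
coefficient = torch.nn.Linear(hidden, 2)  # per-token mixing weights

x = torch.randn(4, hidden)        # (tokens, hidden)
moe_out = torch.randn(4, hidden)  # stand-in for self.moe_layer(x)
mlp_out = torch.randn(4, hidden)  # stand-in for self.residual_mlp(x)

coef = torch.nn.functional.softmax(coefficient(x), dim=-1)  # (tokens, 2), rows sum to 1
# Convex per-token combination, the same expression as in the diff above.
output = moe_out * coef[..., 0:1] + mlp_out * coef[..., 1:]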