reformat code

pull/375/head
Wenwen Qu 2023-08-09 16:03:47 +08:00
parent 4a5cf5d1df
commit 5f2e082b21
2 changed files with 28 additions and 18 deletions

View File

@@ -57,9 +57,11 @@ class PackedFlashBaseLayer1D(nn.Module):
         moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
         moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
+            infinite capacity).
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
-        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
     """
     def __init__(
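For readers unfamiliar with the capacity knobs documented above: in the usual MoE formulation each expert gets a fixed per-batch token budget derived from the capacity factor and clamped by the minimum capacity, and drop_tokens decides what happens when that budget is exceeded. A hypothetical sketch with illustrative names, not this repository's code:

import math

def expert_capacity(num_tokens: int, num_experts: int,
                    capacity_factor: float = 1.0, min_capacity: int = 4) -> int:
    # Each expert keeps roughly its fair share of tokens, scaled by capacity_factor,
    # but never fewer than min_capacity.
    capacity = math.ceil(num_tokens / num_experts * capacity_factor)
    return max(capacity, min_capacity)

# With drop_tokens=True, tokens routed beyond this budget are dropped;
# drop_tokens=False behaves like infinite capacity.
print(expert_capacity(num_tokens=4096, num_experts=8))  # 512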
@@ -184,7 +186,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 device=torch.device("cuda"),
                 dtype=torch.float,
             )
             self.mlp = MoE(
                 hidden_size=hidden_size,
                 experts=experts,
@@ -332,9 +334,11 @@ class PackedFlashInternLm1D(nn.Module):
         moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
         moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
+            to infinite capacity).
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
-        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
     """
     def __init__(
@@ -517,7 +521,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
     all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
     parts = all_parts[pipeline_rank]
     if gpc.is_rank_for_log():
-        logger.info(f"The layer sharding is {all_parts}.")
+        logger.info(f"The layer sharding is {all_parts}.")  # pylint: disable=W1203
     models = []
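For context on the sharding log above: partition_uniform splits num_layers across the pipeline ranks, and each rank then builds only its own slice of layers. A rough, hypothetical illustration of uniform partitioning, not the actual partition_uniform implementation:

def uniform_layer_parts(num_layers: int, pipeline_size: int):
    # Give each pipeline rank a contiguous, near-equal range of layer indices.
    base, rem = divmod(num_layers, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        size = base + (1 if rank < rem else 0)
        parts.append((start, start + size))
        start += size
    return parts

print(uniform_layer_parts(32, 4))  # [(0, 8), (8, 16), (16, 24), (24, 32)]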
@@ -578,7 +582,7 @@ def build_model_with_cfg(
     moe_noisy_gate_policy: str = None,
     moe_drop_tokens: bool = True,
     moe_use_rts: bool = True,
-    moe_use_residual: bool = False,
+    moe_use_residual: bool = True,
 ):
     """
     Builde model with config
@@ -615,9 +619,11 @@ def build_model_with_cfg(
         moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
         moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
-        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
+            to infinite capacity).
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
-        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
     """
     cfg = dict(
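The moe_* arguments documented above are the knobs that build_model_with_cfg forwards to the MoE layers. Grouped as a plain dict, with illustrative values only (not recommended settings):

moe_overrides = dict(
    moe_eval_capacity_factor=1.0,     # expert capacity used at eval time
    moe_min_capacity=4,               # lower bound on per-expert capacity
    moe_noisy_gate_policy="RSample",  # or "Jitter", or None
    moe_drop_tokens=True,             # False behaves like infinite capacity
    moe_use_rts=True,                 # Random Token Selection
    moe_use_residual=False,           # True makes each MoE layer a Residual MoE
)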

View File

@@ -50,9 +50,13 @@ class MoE(torch.nn.Module):
         min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
         noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'
             or 'None'.
+        using_default_moe (bool, optional): default=True, whether to use the default MoE layer.
         drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
             infinite capacity).
         use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
+            (https://arxiv.org/abs/2201.05596) layer.
+        residual_mlp (torch.nn.Module, optional): default=None, the torch module that defines the residual MLP.
     """
     def __init__(
@@ -70,7 +74,7 @@ class MoE(torch.nn.Module):
         use_rts: bool = True,
         using_default_moe: bool = True,
         use_residual=True,
-        residual_mlp=None
+        residual_mlp=None,
     ):
         super().__init__()
@@ -82,7 +86,7 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size
-        logger.info(
+        logger.info(  # pylint: disable=W1203
             f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
             f"{self.num_local_experts} | expert_parallel_size: {self.ep_size}"
         )
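The # pylint: disable=W1203 comments added in this commit silence pylint's logging-fstring-interpolation warning rather than rewriting the calls. The lazy %-formatting alternative that W1203 prefers would look roughly like this (sketch only, not the committed code):

logger.info(
    "Creating MoE layer with num_experts: %s | num_local_experts: %s | expert_parallel_size: %s",
    num_experts,
    self.num_local_experts,
    self.ep_size,
)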
@@ -136,11 +140,11 @@
         """
         output = self.moe_layer(hidden_states, used_token)
         if self.use_residual:
             # Residual MoE
             output_mlp = self.residual_mlp(hidden_states)
-            if type(output_mlp) is tuple:
+            if isinstance(output_mlp, tuple):
                 output_mlp = output_mlp[0]  # Ignore the bias term for now
             coef = self.coefficient(hidden_states)
             coef = torch.nn.functional.softmax(coef, dim=-1)
             output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
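The residual branch above follows the Residual MoE idea (https://arxiv.org/abs/2201.05596): a per-token two-way coefficient is softmaxed and used to mix the MoE output with a dense MLP output. A small self-contained sketch of just that mixing step, with toy shapes and stand-in names:

import torch

hidden_states = torch.randn(2, 3, 4)   # (batch, seq, hidden)
moe_output = torch.randn(2, 3, 4)      # stands in for self.moe_layer(...)
mlp_output = torch.randn(2, 3, 4)      # stands in for self.residual_mlp(...)
coefficient = torch.nn.Linear(4, 2)    # stands in for self.coefficient

# Softmax over the last dim gives per-token weights for (MoE branch, MLP branch).
coef = torch.nn.functional.softmax(coefficient(hidden_states), dim=-1)
mixed = moe_output * coef[..., 0:1] + mlp_output * coef[..., 1:]
print(mixed.shape)  # torch.Size([2, 3, 4])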