add residual and other moe features

pull/375/head
zhanglei 2023-08-09 14:14:18 +08:00
parent bc699ad46f
commit cdf3ed9533
2 changed files with 87 additions and 7 deletions
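Summary: this commit threads a set of MoE gating options (top-k routing via moe_gate_k, train/eval capacity factors, minimum capacity, a noisy-gate policy, token dropping, and random token selection) from build_model_with_cfg through PackedFlashInternLm1D and PackedFlashBaseLayer1D down into the MoE module, renames the expert argument to experts, and adds an optional residual branch: when moe_use_residual is set, a dense FeedForward runs in parallel with the expert layer and the two outputs are blended per token by a learned two-way softmax coefficient.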

View File

@@ -73,6 +73,14 @@ class PackedFlashBaseLayer1D(nn.Module):
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()
self.checkpoint = checkpoint
@@ -135,7 +143,7 @@ class PackedFlashBaseLayer1D(nn.Module):
dtype=dtype,
)
else:
expert = torch.nn.ModuleList(
experts = torch.nn.ModuleList(
[
FeedForward(
hidden_size,
@@ -149,9 +157,34 @@ class PackedFlashBaseLayer1D(nn.Module):
for i in range(num_experts // ep_size)
]
)
# TODO: test moe for now, need more parameter such as: capacity_factor,
# eval_capacity_factor, min_capacity, drop_tokens
self.mlp = MoE(hidden_size=hidden_size, expert=expert, ep_size=ep_size, num_experts=num_experts, k=1)
if moe_use_residual:
residual_mlp = FeedForward(
hidden_size,
int(hidden_size * gpc.config.model.mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=torch.device("cuda"),
dtype=torch.float,
)
self.mlp = MoE(
hidden_size=hidden_size,
experts=experts,
num_experts=num_experts,
ep_size=ep_size,
k=moe_gate_k,
capacity_factor=moe_capacity_factor,
eval_capacity_factor=moe_eval_capacity_factor,
min_capacity=moe_min_capacity,
noisy_gate_policy=moe_noisy_gate_policy,
drop_tokens=moe_drop_tokens,
use_rts=moe_use_rts,
use_residual=moe_use_residual,
residual_mlp=residual_mlp if moe_use_residual else None,
)
self.dropout2 = nn.Dropout(drop_rate)
self.use_swiglu = use_swiglu
self.use_scaled_init = use_scaled_init
@@ -309,6 +342,14 @@ class PackedFlashInternLm1D(nn.Module):
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()
@@ -361,6 +402,14 @@ class PackedFlashInternLm1D(nn.Module):
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)
for lid in range(num_layers)
]
@@ -499,6 +548,14 @@ def build_model_with_cfg(
use_flash_attn: bool = True,
sequence_parallel: bool = False, # pylint: disable=W0613
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
"""
Build model with config
@@ -555,6 +612,14 @@ def build_model_with_cfg(
use_flash_attn=use_flash_attn,
sequence_parallel=sequence_parallel,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
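For context, here is a hypothetical config fragment showing how the new keyword arguments could be supplied to build_model_with_cfg. Only the moe_* key names are taken from this diff; the surrounding structure, the example values, and the policy name are assumptions, not part of the change.

# Hypothetical model config fragment (structure and values are illustrative).
model_cfg = dict(
    num_experts=8,                    # experts per MoE layer
    moe_gate_k=2,                     # route each token to the top-2 experts
    moe_capacity_factor=1.25,         # per-expert capacity factor during training
    moe_eval_capacity_factor=2.0,     # per-expert capacity factor during evaluation
    moe_min_capacity=4,               # lower bound on expert capacity
    moe_noisy_gate_policy="RSample",  # assumed policy name; None disables noisy gating
    moe_drop_tokens=True,             # drop tokens that overflow expert capacity
    moe_use_rts=True,                 # use random token selection in the gate
    moe_use_residual=True,            # enable the dense residual branch
)

The defaults visible in the diff correspond to top-1 gating with a capacity factor of 1.0, token dropping and random token selection enabled, and no residual branch.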

View File

@@ -58,7 +58,7 @@ class MoE(torch.nn.Module):
def __init__(
self,
hidden_size,
expert,
experts,
num_experts=1,
ep_size=1,
k=1,
@@ -69,6 +69,8 @@ class MoE(torch.nn.Module):
drop_tokens: bool = True,
use_rts: bool = True,
using_default_moe: bool = True,
use_residual=True,
residual_mlp=None
):
super().__init__()
@@ -89,7 +91,7 @@ class MoE(torch.nn.Module):
"Unsupported noisy_gate_policy: " + noisy_gate_policy
)
experts = Experts(expert, self.num_local_experts)
experts = Experts(experts, self.num_local_experts)
if using_default_moe:
self.moe_layer = MOELayer(
@@ -110,6 +112,12 @@ class MoE(torch.nn.Module):
self.num_local_experts,
)
self.use_residual = use_residual
if use_residual:
self.residual_mlp = residual_mlp
# coefficient is used for the weighted sum of the expert output and the residual mlp output
self.coefficient = torch.nn.Linear(hidden_size, 2)
def forward(self, hidden_states, used_token=None):
"""MoE forward
@@ -127,5 +135,12 @@ class MoE(torch.nn.Module):
* exp_counts (int): expert count
"""
output = self.moe_layer(hidden_states, used_token)
if self.use_residual:
# Residual MoE
output_mlp = self.residual_mlp(hidden_states)
if type(output_mlp) is tuple:
output_mlp = output_mlp[0] # Ignore the bias term for now
coef = self.coefficient(hidden_states)
coef = torch.nn.functional.softmax(coef, dim=-1)
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
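The residual path in the forward pass above blends the expert output with the dense-MLP output using a learned per-token two-way softmax, in the spirit of Residual MoE. Below is a minimal self-contained sketch of just that combination; the module name and toy shapes are assumptions for illustration, not part of this code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualMixer(nn.Module):
    """Blend an MoE output and a dense-MLP output with a learned
    per-token 2-way coefficient, mirroring the forward pass above."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Two logits per token: weight for the MoE branch vs. the dense branch.
        self.coefficient = nn.Linear(hidden_size, 2)

    def forward(self, hidden_states, moe_out, mlp_out):
        coef = F.softmax(self.coefficient(hidden_states), dim=-1)
        return moe_out * coef[..., 0:1] + mlp_out * coef[..., 1:]

# Toy usage: batch=2, seq=4, hidden=8.
hidden = torch.randn(2, 4, 8)
moe_out = torch.randn(2, 4, 8)   # stands in for self.moe_layer(hidden_states, used_token)
mlp_out = torch.randn(2, 4, 8)   # stands in for self.residual_mlp(hidden_states)
mixed = ResidualMixer(8)(hidden, moe_out, mlp_out)
assert mixed.shape == hidden.shape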