replace flash-attn experts with feedforward experts

pull/182/head
Wenwen Qu 2023-09-08 18:04:57 +08:00
parent cd6b28b073
commit 6cf0fec314
3 changed files with 8 additions and 20 deletions

@@ -2,10 +2,10 @@ import typing
 from typing import Dict, Tuple
 import torch
-from flash_attn.modules.mlp import ParallelFusedMLP
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger
@@ -102,17 +102,12 @@ class MoE(torch.nn.Module):
         experts = torch.nn.ModuleList(
             [
                 # TODO have trouble when use internlm.model.linear.FeedForward
-                ParallelFusedMLP(
+                FeedForward(
                     hidden_size,
                     int(hidden_size * gpc.config.model.mlp_ratio),
                     out_features=hidden_size,
-                    activation="gelu_approx",
                     process_group=gpc.get_group(ParallelMode.TENSOR),
-                    bias1=False,
-                    bias2=False,
-                    sequence_parallel=gpc.config.model.sequence_parallel,
-                    checkpoint_lvl=0,
-                    heuristic="auto",
+                    bias=False,
                     device=device,
                     dtype=dtype,
                 )
@@ -143,17 +138,12 @@ class MoE(torch.nn.Module):
         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
         self.use_residual = use_residual
         if use_residual:
-            self.residual_mlp = ParallelFusedMLP(
+            self.residual_mlp = FeedForward(
                 hidden_size,
                 int(hidden_size * gpc.config.model.mlp_ratio),
                 out_features=hidden_size,
-                activation="gelu_approx",
                 process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias1=False,
-                bias2=False,
-                sequence_parallel=gpc.config.model.sequence_parallel,
-                checkpoint_lvl=0,
-                heuristic="auto",
+                bias=False,
                 device=device,
                 dtype=dtype,
             )
@@ -188,9 +178,7 @@ class MoE(torch.nn.Module):
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
 
 
-def split_params_into_different_moe_groups_for_optimizer(
-    param_groups: Tuple[Dict], max_group_size=178956971
-) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer
     Compatiable with muiltiple param groups, each should have a name
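The max_group_size default changes from 178956971 (roughly 2**31 / 12, a cap apparently meant to keep flattened parameter buffers within 32-bit element limits) to None, so expert parameter groups are no longer chunked by default. A hedged sketch of the split this function's name and the is_expert tagging below imply; the helper name, chunking logic, and dict keys are illustrative, not the actual implementation:

from typing import Dict, Tuple


def split_params_sketch(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
    """Illustrative only: separate expert params (tagged param.is_expert) from dense ones."""
    new_groups = []
    for group in param_groups:
        dense = [p for p in group["params"] if not getattr(p, "is_expert", False)]
        expert = [p for p in group["params"] if getattr(p, "is_expert", False)]
        new_groups.append({**group, "params": dense})
        # With max_group_size=None the expert params stay in one group;
        # otherwise they are chunked so no group exceeds the element cap.
        chunks, cur, numel = [], [], 0
        for p in expert:
            if max_group_size is not None and cur and numel + p.numel() > max_group_size:
                chunks.append(cur)
                cur, numel = [], 0
            cur.append(p)
            numel += p.numel()
        if cur:
            chunks.append(cur)
        for c in chunks:
            new_groups.append({**group, "params": c, "moe": True})
    return tuple(new_groups)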

@@ -37,7 +37,7 @@ class Experts(torch.nn.Module):
         for expert in self.experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
-                param.belong_expert = True
+                param.is_expert = True
                 param.group_name = expert_group_name
 
     def forward(self, inputs):
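The rename from belong_expert to is_expert is the attribute the optimizer-group split above keys on. A quick stand-alone illustration of the tagging pattern (the Linear module and group name are placeholders, not the real expert modules):

import torch

expert = torch.nn.Linear(8, 8)  # stand-in for one expert module
for _, param in expert.named_parameters():
    param.is_expert = True            # attribute read back by the optimizer split
    param.group_name = "moe_group_0"  # placeholder group name

moe_params = [p for p in expert.parameters() if getattr(p, "is_expert", False)]
assert len(moe_params) == 2  # weight and bias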

@@ -382,7 +382,7 @@ def record_current_batch_training_metrics(
         infos = {
             "tflops": tflops,
             "step": batch_count,
-            "loss": loss.item(),
+            "loss": loss.item() - moe_loss.item(),
             "moe_loss": moe_loss.item(),
             "tgs (tokens/gpu/second)": tk_per_gpu,
             "lr": lr,