replace flash-attn experts with feedforward experts

pull/182/head
Wenwen Qu 2023-09-08 18:04:57 +08:00
parent cd6b28b073
commit 6cf0fec314
3 changed files with 8 additions and 20 deletions

View File

@@ -2,10 +2,10 @@ import typing
 from typing import Dict, Tuple

 import torch
-from flash_attn.modules.mlp import ParallelFusedMLP

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger
@@ -102,17 +102,12 @@ class MoE(torch.nn.Module):
         experts = torch.nn.ModuleList(
             [
                 # TODO have trouble when use internlm.model.linear.FeedForward
-                ParallelFusedMLP(
+                FeedForward(
                     hidden_size,
                     int(hidden_size * gpc.config.model.mlp_ratio),
                     out_features=hidden_size,
-                    activation="gelu_approx",
                     process_group=gpc.get_group(ParallelMode.TENSOR),
-                    bias1=False,
-                    bias2=False,
-                    sequence_parallel=gpc.config.model.sequence_parallel,
-                    checkpoint_lvl=0,
-                    heuristic="auto",
+                    bias=False,
                     device=device,
                     dtype=dtype,
                 )
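
The hunk above boils down to a constructor swap. The hypothetical helper below is not part of the repo; only the ParallelFusedMLP and FeedForward classes come from the imports in this diff, everything else is illustrative. It contrasts the two expert MLPs and shows which kwargs disappear:

    # Hypothetical helper (not in the repo) contrasting the two constructors; the
    # ParallelFusedMLP and FeedForward classes are the ones imported in the diff.
    def build_expert(hidden_size, mlp_ratio, process_group, device, dtype,
                     sequence_parallel=False, use_flash_mlp=False):
        if use_flash_mlp:
            # flash-attn fused MLP: per-linear bias flags plus sequence-parallel,
            # activation-checkpointing and kernel-heuristic knobs.
            return ParallelFusedMLP(
                hidden_size,
                int(hidden_size * mlp_ratio),
                out_features=hidden_size,
                activation="gelu_approx",
                process_group=process_group,
                bias1=False,
                bias2=False,
                sequence_parallel=sequence_parallel,
                checkpoint_lvl=0,
                heuristic="auto",
                device=device,
                dtype=dtype,
            )
        # InternLM FeedForward: a single bias flag; activation and parallel
        # behaviour are fixed inside the module, so the extra kwargs go away.
        return FeedForward(
            hidden_size,
            int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            process_group=process_group,
            bias=False,
            device=device,
            dtype=dtype,
        )

FeedForward exposes a single bias flag and keeps its activation and parallelism details internal, which is why activation, bias1/bias2, sequence_parallel, checkpoint_lvl and heuristic drop out of the call.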
@@ -143,17 +138,12 @@ class MoE(torch.nn.Module):
         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
         self.use_residual = use_residual
         if use_residual:
-            self.residual_mlp = ParallelFusedMLP(
+            self.residual_mlp = FeedForward(
                 hidden_size,
                 int(hidden_size * gpc.config.model.mlp_ratio),
                 out_features=hidden_size,
-                activation="gelu_approx",
                 process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias1=False,
-                bias2=False,
-                sequence_parallel=gpc.config.model.sequence_parallel,
-                checkpoint_lvl=0,
-                heuristic="auto",
+                bias=False,
                 device=device,
                 dtype=dtype,
             )
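
The residual branch follows the Residual-MoE design from the paper linked in the comment: a dense MLP runs alongside the routed experts and a learned coefficient mixes the two outputs. A minimal sketch of that mixing, where residual_moe_combine and the 2-way coefficient layer are assumptions rather than this repo's code:

    import torch.nn.functional as F

    def residual_moe_combine(moe_output, residual_output, coef_logits):
        """Sketch of the Residual-MoE mixing from https://arxiv.org/pdf/2201.05596.pdf.

        coef_logits would come from an assumed nn.Linear(hidden_size, 2) applied to
        the block input; none of these names are taken from the repo.
        """
        coef = F.softmax(coef_logits, dim=-1)  # two mixing weights per token
        return moe_output * coef[..., 0:1] + residual_output * coef[..., 1:2]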
@@ -188,9 +178,7 @@ class MoE(torch.nn.Module):
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts


-def split_params_into_different_moe_groups_for_optimizer(
-    param_groups: Tuple[Dict], max_group_size=178956971
-) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer

     Compatiable with muiltiple param groups, each should have a name
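
Only the signature of split_params_into_different_moe_groups_for_optimizer appears in this hunk; its default max_group_size changes from a hard-coded constant to None. A rough sketch of what such a split typically does, assuming the is_expert / group_name attributes set in internlm/moe/experts.py below; this is illustrative only and ignores max_group_size, which presumably limits how many elements end up in any one group:

    # Illustrative sketch, not the repo's implementation: pull expert parameters
    # out of each named param group and give every expert group its own
    # optimizer group.
    def split_moe_groups_sketch(param_groups):
        new_groups = []
        for group in param_groups:
            dense_params, expert_params = [], {}
            for param in group["params"]:
                if getattr(param, "is_expert", False):
                    expert_params.setdefault(param.group_name, []).append(param)
                else:
                    dense_params.append(param)
            group["params"] = dense_params
            new_groups.append(group)
            for name, params in expert_params.items():
                moe_group = dict(group, params=params, name=f"{group['name']}_{name}", moe=True)
                new_groups.append(moe_group)
        return new_groups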

View File

@@ -37,7 +37,7 @@ class Experts(torch.nn.Module):
         for expert in self.experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
-                param.belong_expert = True
+                param.is_expert = True
                 param.group_name = expert_group_name

     def forward(self, inputs):
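
The hunk stops at the forward signature. For reference, a DeepSpeed-style Experts.forward splits the dispatched tokens across the local experts and concatenates the results; the standalone sketch below assumes that pattern and is not necessarily this repo's exact implementation:

    import torch

    def experts_forward_sketch(experts, num_local_experts, inputs):
        # DeepSpeed-style pattern: split the dispatched tokens across the local
        # experts along the expert dimension, run each chunk, and re-concatenate.
        chunks = inputs.chunk(num_local_experts, dim=1)
        outputs = []
        for chunk, expert in zip(chunks, experts):
            out = expert(chunk)
            if isinstance(out, tuple):  # some MLP variants return (hidden, bias)
                out = out[0]
            outputs.append(out)
        return torch.cat(outputs, dim=1)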

View File

@@ -382,7 +382,7 @@ def record_current_batch_training_metrics(
         infos = {
             "tflops": tflops,
             "step": batch_count,
-            "loss": loss.item(),
+            "loss": loss.item() - moe_loss.item(),
             "moe_loss": moe_loss.item(),
             "tgs (tokens/gpu/second)": tk_per_gpu,
             "lr": lr,