mirror of https://github.com/InternLM/InternLM
replace flashatten experts by feedforward experts
parent cd6b28b073
commit 6cf0fec314

@@ -2,10 +2,10 @@ import typing
 from typing import Dict, Tuple
 
 import torch
-from flash_attn.modules.mlp import ParallelFusedMLP
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger

@@ -102,17 +102,12 @@ class MoE(torch.nn.Module):
         experts = torch.nn.ModuleList(
             [
-                # TODO have trouble when use internlm.model.linear.FeedForward
-                ParallelFusedMLP(
+                FeedForward(
                     hidden_size,
                     int(hidden_size * gpc.config.model.mlp_ratio),
                     out_features=hidden_size,
-                    activation="gelu_approx",
                     process_group=gpc.get_group(ParallelMode.TENSOR),
-                    bias1=False,
-                    bias2=False,
-                    sequence_parallel=gpc.config.model.sequence_parallel,
-                    checkpoint_lvl=0,
-                    heuristic="auto",
+                    bias=False,
                     device=device,
                     dtype=dtype,
                 )

@@ -143,17 +138,12 @@ class MoE(torch.nn.Module):
         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
         self.use_residual = use_residual
         if use_residual:
-            self.residual_mlp = ParallelFusedMLP(
+            self.residual_mlp = FeedForward(
                 hidden_size,
                 int(hidden_size * gpc.config.model.mlp_ratio),
                 out_features=hidden_size,
-                activation="gelu_approx",
                 process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias1=False,
-                bias2=False,
-                sequence_parallel=gpc.config.model.sequence_parallel,
-                checkpoint_lvl=0,
-                heuristic="auto",
+                bias=False,
                 device=device,
                 dtype=dtype,
             )
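
For orientation, below is a minimal non-parallel sketch of what each expert computes after this change, assuming InternLM's FeedForward is the gated-SiLU (SwiGLU-style) MLP used by the dense blocks. The class name SimpleFeedForward and the w1/w2/w3 naming are illustrative, not the repository's; the real FeedForward additionally shards its projections over the tensor-parallel process_group passed in the calls above.

# Illustrative sketch only -- not the repository's FeedForward implementation.
# Assumes a gated-SiLU MLP with no bias terms: out = w3(silu(w1(x)) * w2(x)).
import torch
from torch import nn
from torch.nn import functional as F

class SimpleFeedForward(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=bias)  # gate projection
        self.w2 = nn.Linear(hidden_size, intermediate_size, bias=bias)  # up projection
        self.w3 = nn.Linear(intermediate_size, hidden_size, bias=bias)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

The flash_attn-specific arguments dropped from both constructor calls (activation="gelu_approx", bias1/bias2, checkpoint_lvl, heuristic) have no counterpart in this interface, which is presumably why the switch reduces each call to the same small argument set.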

@@ -188,9 +178,7 @@ class MoE(torch.nn.Module):
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
 
 
-def split_params_into_different_moe_groups_for_optimizer(
-    param_groups: Tuple[Dict], max_group_size=178956971
-) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer
     Compatiable with muiltiple param groups, each should have a name

@@ -37,7 +37,7 @@ class Experts(torch.nn.Module):
         for expert in self.experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
-                param.belong_expert = True
+                param.is_expert = True
                 param.group_name = expert_group_name
 
     def forward(self, inputs):
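
The is_expert and group_name attributes tagged here are what a helper like split_params_into_different_moe_groups_for_optimizer (whose signature changes earlier in this diff) can key on when building optimizer groups. A hedged sketch of that idea follows; it is not the repository's implementation, and the function name plus the "params"/"name" keys of each group dict are assumptions.

# Hedged sketch: peel tagged expert parameters out of each optimizer param group.
# Not the repository's implementation; the "params"/"name" dict keys are assumed.
from typing import Dict, List

def split_expert_param_groups(param_groups: List[Dict]) -> List[Dict]:
    new_groups = []
    for group in param_groups:
        dense, expert = [], []
        for p in group["params"]:
            (expert if getattr(p, "is_expert", False) else dense).append(p)
        new_groups.append({**group, "params": dense})
        if expert:
            # keep the group's hyperparameters, but give the expert slice its own name
            new_groups.append({**group, "params": expert,
                               "name": group.get("name", "default") + "_moe"})
    return new_groups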

@@ -382,7 +382,7 @@ def record_current_batch_training_metrics(
         infos = {
             "tflops": tflops,
             "step": batch_count,
-            "loss": loss.item(),
+            "loss": loss.item() - moe_loss.item(),
             "moe_loss": moe_loss.item(),
             "tgs (tokens/gpu/second)": tk_per_gpu,
             "lr": lr,
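
On the logging change: assuming the loss handed to this function already contains the auxiliary load-balancing term (i.e. loss = lm_loss + moe_loss, which is an assumption here, not something the diff states), subtracting moe_loss.item() makes the reported "loss" the language-modeling term alone while "moe_loss" is shown separately. A tiny sketch of the arithmetic with hypothetical values:

import torch

lm_loss = torch.tensor(2.31)   # hypothetical language-modeling loss
moe_loss = torch.tensor(0.04)  # hypothetical auxiliary (load-balancing) loss
loss = lm_loss + moe_loss      # combined value the trainer would pass in

print(loss.item() - moe_loss.item())  # logged under "loss"     -> ~2.31
print(moe_loss.item())                # logged under "moe_loss" -> ~0.04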