Mirror of https://github.com/InternLM/InternLM

commit 6cf0fec314 (parent cd6b28b073): replace flashatten experts by feedforward experts
@@ -2,10 +2,10 @@ import typing
 from typing import Dict, Tuple
 
 import torch
 
-from flash_attn.modules.mlp import ParallelFusedMLP
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger
@@ -102,17 +102,12 @@ class MoE(torch.nn.Module):
         experts = torch.nn.ModuleList(
             [
                 # TODO have trouble when use internlm.model.linear.FeedForward
-                ParallelFusedMLP(
+                FeedForward(
                     hidden_size,
                     int(hidden_size * gpc.config.model.mlp_ratio),
                     out_features=hidden_size,
-                    activation="gelu_approx",
                     process_group=gpc.get_group(ParallelMode.TENSOR),
-                    bias1=False,
-                    bias2=False,
-                    sequence_parallel=gpc.config.model.sequence_parallel,
-                    checkpoint_lvl=0,
-                    heuristic="auto",
+                    bias=False,
                     device=device,
                     dtype=dtype,
                 )
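For reference, internlm.model.linear.FeedForward has a different constructor surface than flash-attn's ParallelFusedMLP, which is why the activation=, bias1=/bias2=, sequence_parallel=, checkpoint_lvl= and heuristic= arguments disappear: there is a single bias flag and the activation is baked into the module. A minimal single-device sketch of what the replacement module looks like, assuming it is a SiLU-gated (SwiGLU-style) MLP; the real class shards w1/w2/w3 over the tensor-parallel process_group, which is omitted here:

import torch
import torch.nn.functional as F

class FeedForwardSketch(torch.nn.Module):
    """Illustrative stand-in for internlm.model.linear.FeedForward (not the real class)."""

    def __init__(self, in_features, hidden_features, out_features=None,
                 bias=False, device=None, dtype=None):
        super().__init__()
        out_features = out_features or in_features
        kwargs = {"bias": bias, "device": device, "dtype": dtype}
        self.w1 = torch.nn.Linear(in_features, hidden_features, **kwargs)   # gate projection
        self.w2 = torch.nn.Linear(in_features, hidden_features, **kwargs)   # up projection
        self.w3 = torch.nn.Linear(hidden_features, out_features, **kwargs)  # down projection

    def forward(self, x):
        # SiLU-gated MLP: w3(silu(w1(x)) * w2(x)); no separate activation argument needed.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))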
@@ -143,17 +138,12 @@ class MoE(torch.nn.Module):
         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
         self.use_residual = use_residual
         if use_residual:
-            self.residual_mlp = ParallelFusedMLP(
+            self.residual_mlp = FeedForward(
                 hidden_size,
                 int(hidden_size * gpc.config.model.mlp_ratio),
                 out_features=hidden_size,
-                activation="gelu_approx",
                 process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias1=False,
-                bias2=False,
-                sequence_parallel=gpc.config.model.sequence_parallel,
-                checkpoint_lvl=0,
-                heuristic="auto",
+                bias=False,
                 device=device,
                 dtype=dtype,
             )
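The residual branch follows the residual-MoE design from the cited paper (https://arxiv.org/pdf/2201.05596.pdf): every token goes through a dense MLP in addition to the sparse experts, and the two paths are mixed with a learned per-token coefficient. A hedged sketch of how such a forward pass is typically combined; the coefficient layer and exact mixing here are assumptions from the paper, not a copy of the InternLM code:

import torch

def residual_moe_forward(hidden_states, moe_layer, residual_mlp, coefficient):
    # Sparse expert path: each token is routed to its top-k experts.
    expert_out = moe_layer(hidden_states)
    # Dense residual path: one shared MLP applied to every token.
    residual_out = residual_mlp(hidden_states)
    # Learned per-token mixing weights between the two paths, shape [..., 2].
    coef = torch.softmax(coefficient(hidden_states), dim=-1)
    return expert_out * coef[..., 0:1] + residual_out * coef[..., 1:]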
@@ -188,9 +178,7 @@ class MoE(torch.nn.Module):
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
 
 
-def split_params_into_different_moe_groups_for_optimizer(
-    param_groups: Tuple[Dict], max_group_size=178956971
-) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer
     Compatiable with muiltiple param groups, each should have a name
 
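For context, this helper walks the optimizer's param groups and moves expert-tagged parameters into their own groups so they can be handled separately under expert parallelism; with max_group_size now defaulting to None, the additional size-based chunking of expert groups is off unless explicitly requested. A simplified sketch of the split under that assumption; field names other than "params" are illustrative, not the exact InternLM keys:

from typing import Dict, List, Tuple

def split_moe_groups_sketch(param_groups: Tuple[Dict, ...]) -> List[Dict]:
    new_groups = []
    for group in param_groups:
        dense, moe = [], []
        for param in group["params"]:
            # Relies on the is_expert flag set by Experts (see the hunk below).
            (moe if getattr(param, "is_expert", False) else dense).append(param)
        group["params"] = dense           # original group keeps only dense params
        new_groups.append(group)
        if moe:                           # expert params get a sibling group
            moe_group = {k: v for k, v in group.items() if k != "params"}
            moe_group.update(params=moe, moe=True)
            new_groups.append(moe_group)
    return new_groups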
@@ -37,7 +37,7 @@ class Experts(torch.nn.Module):
         for expert in self.experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
-                param.belong_expert = True
+                param.is_expert = True
                 param.group_name = expert_group_name
 
     def forward(self, inputs):
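Renaming belong_expert to is_expert keeps the attribute consistent with the flag the optimizer-group splitting checks, and any downstream code can filter expert parameters the same way. Hypothetical usage, where model is any torch.nn.Module containing these Experts:

# Collect the parameters tagged by Experts above (illustrative, not repo code).
expert_params = [p for p in model.parameters() if getattr(p, "is_expert", False)]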
@@ -382,7 +382,7 @@ def record_current_batch_training_metrics(
         infos = {
             "tflops": tflops,
             "step": batch_count,
-            "loss": loss.item(),
+            "loss": loss.item() - moe_loss.item(),
             "moe_loss": moe_loss.item(),
             "tgs (tokens/gpu/second)": tk_per_gpu,
             "lr": lr,
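This change presumably reflects that the optimized loss already folds in the MoE auxiliary (load-balancing) term, so subtracting it makes the logged "loss" the language-modeling term alone and the two metrics stop double-counting. Restated as bookkeeping (a sketch of the intent, not additional repo code):

# assumed decomposition: loss = lm_loss + moe_loss (auxiliary balancing term)
infos["loss"] = loss.item() - moe_loss.item()  # report the LM term by itself
infos["moe_loss"] = moe_loss.item()            # report the auxiliary term by itself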