From 6cf0fec314a39b90ff1a7061ddfb9d40ab0bde08 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 8 Sep 2023 18:04:57 +0800
Subject: [PATCH] replace flash-attn experts with feedforward experts

---
 internlm/model/moe.py               | 24 ++++++------------------
 internlm/moe/experts.py             |  2 +-
 internlm/train/training_internlm.py |  2 +-
 3 files changed, 8 insertions(+), 20 deletions(-)

diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index a978d31..631b85f 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -2,10 +2,10 @@ import typing
 from typing import Dict, Tuple
 
 import torch
-from flash_attn.modules.mlp import ParallelFusedMLP
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger
@@ -102,17 +102,12 @@ class MoE(torch.nn.Module):
         experts = torch.nn.ModuleList(
             [
                 # TODO have trouble when use internlm.model.linear.FeedForward
-                ParallelFusedMLP(
+                FeedForward(
                     hidden_size,
                     int(hidden_size * gpc.config.model.mlp_ratio),
                     out_features=hidden_size,
-                    activation="gelu_approx",
                     process_group=gpc.get_group(ParallelMode.TENSOR),
-                    bias1=False,
-                    bias2=False,
-                    sequence_parallel=gpc.config.model.sequence_parallel,
-                    checkpoint_lvl=0,
-                    heuristic="auto",
+                    bias=False,
                     device=device,
                     dtype=dtype,
                 )
@@ -143,17 +138,12 @@ class MoE(torch.nn.Module):
         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
         self.use_residual = use_residual
         if use_residual:
-            self.residual_mlp = ParallelFusedMLP(
+            self.residual_mlp = FeedForward(
                 hidden_size,
                 int(hidden_size * gpc.config.model.mlp_ratio),
                 out_features=hidden_size,
-                activation="gelu_approx",
                 process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias1=False,
-                bias2=False,
-                sequence_parallel=gpc.config.model.sequence_parallel,
-                checkpoint_lvl=0,
-                heuristic="auto",
+                bias=False,
                 device=device,
                 dtype=dtype,
             )
@@ -188,9 +178,7 @@ class MoE(torch.nn.Module):
         return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
 
 
-def split_params_into_different_moe_groups_for_optimizer(
-    param_groups: Tuple[Dict], max_group_size=178956971
-) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict], max_group_size=None) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer
 
     Compatiable with muiltiple param groups, each should have a name
diff --git a/internlm/moe/experts.py b/internlm/moe/experts.py
index d57714e..ab93a0f 100644
--- a/internlm/moe/experts.py
+++ b/internlm/moe/experts.py
@@ -37,7 +37,7 @@ class Experts(torch.nn.Module):
         for expert in self.experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
-                param.belong_expert = True
+                param.is_expert = True
                 param.group_name = expert_group_name
 
     def forward(self, inputs):
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index ab73558..54cea47 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -382,7 +382,7 @@ def record_current_batch_training_metrics(
         infos = {
             "tflops": tflops,
             "step": batch_count,
-            "loss": loss.item(),
+            "loss": loss.item() - moe_loss.item(),
             "moe_loss": moe_loss.item(),
             "tgs (tokens/gpu/second)": tk_per_gpu,
             "lr": lr,
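
Below is a minimal, self-contained sketch (plain PyTorch, no tensor parallelism or process groups) of the pattern the patch moves to: bias-free feed-forward experts in a ModuleList, plus the `is_expert` / `group_name` tagging that `internlm/moe/experts.py` applies after the rename. `SimpleFeedForward`, the toy sizes, and the group name are illustrative stand-ins, not the real `internlm.model.linear.FeedForward` (which additionally takes a `process_group`, `device`, and `dtype`).

# Illustrative stand-in for the expert MLP; the real FeedForward is tensor-parallel,
# which this sketch deliberately omits.
import torch


class SimpleFeedForward(torch.nn.Module):
    def __init__(self, hidden_size: int, inner_size: int, out_features: int):
        super().__init__()
        # bias=False mirrors the bias=False argument used in the patch
        self.w1 = torch.nn.Linear(hidden_size, inner_size, bias=False)
        self.w2 = torch.nn.Linear(inner_size, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.nn.functional.silu(self.w1(x)))


hidden_size, mlp_ratio, num_local_experts = 512, 4, 2  # assumed toy sizes
experts = torch.nn.ModuleList(
    [
        SimpleFeedForward(hidden_size, hidden_size * mlp_ratio, hidden_size)
        for _ in range(num_local_experts)
    ]
)

# Tag expert parameters the way Experts.__init__ does after the rename
# (param.is_expert, param.group_name); the group name here is hypothetical.
for expert in experts:
    for _, param in expert.named_parameters():
        param.is_expert = True
        param.group_name = "moe_ep_group_0"

x = torch.randn(4, hidden_size)
print(experts[0](x).shape)  # torch.Size([4, 512])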