refactor code

pull/567/head
Wenwen Qu 2024-01-08 14:33:19 +08:00
parent fdd60691d3
commit c3854f924a
4 changed files with 77 additions and 44 deletions

View File

@ -175,7 +175,7 @@ class PackedFlashBaseLayer1D(nn.Module):
hidden_size=hidden_size,
num_experts=num_experts,
ep_size=ep_size,
k=moe_gate_k,
topk=moe_gate_k,
capacity_factor=moe_capacity_factor,
eval_capacity_factor=moe_eval_capacity_factor,
min_capacity=moe_min_capacity,

View File

@ -5,8 +5,7 @@ import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.linear import FeedForward
from internlm.moe.experts import Experts
from internlm.moe.sharded_moe import GShardMOELayer, TopKGate
from internlm.moe.sharded_moe import GShardMOELayer
from internlm.utils.logger import get_logger
# global llm logger
@ -40,7 +39,7 @@ class MoE(torch.nn.Module):
hidden_size,
num_experts=1,
ep_size=1,
k=1,
topk=1,
capacity_factor=1.0,
eval_capacity_factor=1.0,
min_capacity=4,
@ -66,43 +65,21 @@ class MoE(torch.nn.Module):
"Unsupported noisy_gate_policy: " + noisy_gate_policy
)
# for elastic expert paralle, experts may have multiple groups
expert_group_name = f"moe_ep_size_{self.ep_size}"
if expert_group_name not in gpc.expert_parallel_group_names:
gpc.expert_parallel_group_names.append(expert_group_name)
experts = torch.nn.ModuleList(
[
FeedForward(
hidden_size,
int(hidden_size * gpc.config.model.mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
for _ in range(self.num_local_experts)
]
)
experts = Experts(experts, self.num_local_experts, expert_group_name)
if using_default_moe:
self.moe_layer = GShardMOELayer(
TopKGate(
hidden_size,
gpc.get_group(ParallelMode.EXPERT),
ep_size,
num_experts,
k,
topk,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts,
),
experts,
gpc.get_group(ParallelMode.EXPERT),
self.ep_size,
self.num_local_experts,
device,
dtype,
)
# residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence

View File

@ -1,7 +1,10 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from torch import Tensor
from torch.nn import Module
from torch.nn import Module, ModuleList
from internlm.core.context import global_context as gpc
from internlm.moe.experts import Experts
if TYPE_CHECKING:
Base = Module[Tensor]
@ -14,10 +17,16 @@ class BaseMoELayer(Base):
Base MoE Layer.
"""
def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
def __init__(
self, gate: Module, experts: Union[Module, ModuleList], ep_group, ep_size: int, num_local_experts: int
) -> None:
super().__init__()
# for elastic expert paralle, experts may have multiple groups
expert_group_name = f"moe_ep_size_{ep_size}"
if expert_group_name not in gpc.expert_parallel_group_names:
gpc.expert_parallel_group_names.append(expert_group_name)
self.gate = gate
self.experts = experts
self.experts = Experts(experts, num_local_experts, expert_group_name)
self.ep_group = ep_group
self.ep_size = ep_size
self.num_local_experts = num_local_experts

View File

@ -12,6 +12,9 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.linear import FeedForward
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
@ -379,8 +382,52 @@ class GShardMOELayer(BaseMoELayer):
expert network
"""
def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
super().__init__(gate, experts, ep_group, ep_size, num_local_experts)
def __init__(
self,
hidden_size,
ep_group,
ep_size: int,
num_experts: int,
topk,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts,
device=None,
dtype=None,
) -> None:
super().__init__(
TopKGate(
hidden_size,
num_experts,
topk,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts,
),
torch.nn.ModuleList(
[
FeedForward(
hidden_size,
int(hidden_size * gpc.config.model.mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
for _ in range(num_experts // ep_size)
]
),
ep_group,
ep_size,
num_experts // ep_size,
)
self.time_falltoall = 0.0
self.time_salltoall = 0.0