InternLM/internlm/moe/base_moe.py

from typing import TYPE_CHECKING, Union

from torch import Tensor
from torch.nn import Module, ModuleList

from internlm.core.context import global_context as gpc
from internlm.moe.experts import Experts

if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module


class BaseMoELayer(Base):
    """
    Base MoE Layer.
    """

    def __init__(
        self, gate: Module, experts: Union[Module, ModuleList], ep_group, ep_size: int, num_local_experts: int
    ) -> None:
        super().__init__()
        # for elastic expert parallelism, experts may have multiple groups
        expert_group_name = f"moe_ep_size_{ep_size}"
        if expert_group_name not in gpc.expert_parallel_group_names:
            gpc.expert_parallel_group_names.append(expert_group_name)
        self.gate = gate
        self.experts = Experts(experts, num_local_experts, expert_group_name)
        self.ep_group = ep_group
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
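

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal,
# hypothetical subclass showing how a gate and per-rank expert modules are
# wired into BaseMoELayer. `ExampleMoELayer` and `build_example_moe` are
# assumed names for illustration; InternLM's real MoE layers (with an actual
# top-k gate and a forward() that dispatches tokens) live elsewhere in the
# package. Constructing the layer assumes gpc and the expert-parallel process
# group (`ep_group`) have already been initialized by the training framework.
# ---------------------------------------------------------------------------
from torch.nn import Linear


class ExampleMoELayer(BaseMoELayer):
    """Hypothetical subclass; a real layer would also implement forward()."""


def build_example_moe(hidden_size: int, ep_group, ep_size: int, num_local_experts: int) -> ExampleMoELayer:
    # Toy router: scores each token against every global expert. InternLM's
    # actual gating module is swapped out here for brevity.
    gate = Linear(hidden_size, ep_size * num_local_experts, bias=False)
    # One toy FFN per local expert; the Experts() wrapper receives the
    # expert-parallel group name derived in BaseMoELayer.__init__.
    experts = ModuleList(Linear(hidden_size, hidden_size) for _ in range(num_local_experts))
    return ExampleMoELayer(gate, experts, ep_group, ep_size, num_local_experts)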