diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 63dcb63..4c7b741 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -5,7 +5,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe.sharded_moe import GShardMOELayer
+from internlm.moe import GShardMOELayer
 from internlm.utils.logger import get_logger
 
 # global llm logger
@@ -46,7 +46,7 @@ class MoE(torch.nn.Module):
         noisy_gate_policy: typing.Optional[str] = None,
         drop_tokens: bool = True,
         use_rts: bool = True,
-        using_default_moe: bool = True,
+        moe_type: str = None,
         use_residual=False,
         device=None,
         dtype=None,
@@ -65,7 +65,7 @@ class MoE(torch.nn.Module):
                 "Unsupported noisy_gate_policy: " + noisy_gate_policy
             )
 
-        if using_default_moe:
+        if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
                 gpc.get_group(ParallelMode.EXPERT),
diff --git a/internlm/moe/__init__.py b/internlm/moe/__init__.py
index e69de29..fdda02c 100644
--- a/internlm/moe/__init__.py
+++ b/internlm/moe/__init__.py
@@ -0,0 +1,3 @@
+from internlm.moe.sharded_moe import GShardMOELayer
+
+__all__ = ["GShardMOELayer"]
diff --git a/internlm/moe/base_moe.py b/internlm/moe/base_moe.py
index 459771f..5f19e07 100644
--- a/internlm/moe/base_moe.py
+++ b/internlm/moe/base_moe.py
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING, Union
 
+import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
@@ -30,3 +31,5 @@ class BaseMoELayer(Base):
         self.ep_group = ep_group
         self.ep_size = ep_size
         self.num_local_experts = num_local_experts
+        self.l_aux = torch.zeros(1, device=torch.cuda.current_device())
+        self.exp_counts = None
diff --git a/internlm/moe/experts.py b/internlm/moe/experts.py
index be06686..df9ceb7 100644
--- a/internlm/moe/experts.py
+++ b/internlm/moe/experts.py
@@ -18,17 +18,14 @@ class Experts(torch.nn.Module):
     def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
         super().__init__()
 
-        # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
-        # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
-
         if isinstance(experts, ModuleList):
-            self.experts = cast(ModuleList, experts)
+            self.wrapped_experts = cast(ModuleList, experts)
         else:
-            self.experts = ModuleList([experts])
+            self.wrapped_experts = ModuleList([experts])
         self.num_local_experts = num_local_experts
 
         # TODO: revisit allreduce for moe.gate...
-        for expert in self.experts:
+        for expert in self.wrapped_experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
                 param.is_expert = True
@@ -37,7 +34,7 @@
     def forward(self, inputs):
         chunks = inputs.chunk(self.num_local_experts, dim=1)
         expert_outputs = []
-        for chunk, expert in zip(chunks, self.experts):
+        for chunk, expert in zip(chunks, self.wrapped_experts):
             out = expert(chunk)
             if isinstance(out, tuple):
                 out = out[0]  # Ignore the bias term for now
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 0bc7261..0c74227 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -614,7 +614,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
             for n, p in module.state_dict().items():
                 if "expert" in n and "moe_layer.gate" not in n:
                     moe_state_dict[n_module + "." + n] = p
-            moe_str_prefix = ".moe_layer.experts.experts."
+            moe_str_prefix = ".moe_layer.experts.wrapped_experts."
             # Reorder the moe name rank, so that each checkpoint only has one expert
             experts_state_dict = defaultdict(dict)
             for key in list(moe_state_dict.keys()):
@@ -696,7 +696,7 @@ def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
                 fp = os.path.join(folder, fn)
                 expert_state_dict = llm_load(fp, map_location=get_current_device())
                 # Updating global -> local expert ids
-                moe_str_prefix = ".moe_layer.experts.experts."
+                moe_str_prefix = ".moe_layer.experts.wrapped_experts."
                 for key in list(expert_state_dict.keys()):
                     local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
                     expert_state_dict[local_key] = expert_state_dict.pop(key)
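
For reviewers, a standalone sketch of the selection logic that replaces using_default_moe: an optional moe_type string now picks the implementation, and GShard stays the default when it is left as None. This sketch is not part of the patch; GShardMOELayer below is a placeholder for internlm.moe.GShardMOELayer, and build_moe_layer is a hypothetical helper that only mirrors the condition in MoE.__init__.

# Hypothetical sketch, not part of the patch: mirrors the moe_type check in
# MoE.__init__ above. GShardMOELayer here is a placeholder for the real class.


class GShardMOELayer:
    """Placeholder for internlm.moe.GShardMOELayer."""


def build_moe_layer(moe_type: str = None):
    # None and "GShard" both select the default GShard implementation,
    # matching `if moe_type is None or moe_type == "GShard"` in the patch.
    if moe_type is None or moe_type == "GShard":
        return GShardMOELayer()
    raise ValueError(f"Unsupported moe_type: {moe_type}")


assert isinstance(build_moe_layer(), GShardMOELayer)          # default selection
assert isinstance(build_moe_layer("GShard"), GShardMOELayer)  # explicit selection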
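
Because the experts submodule is renamed from experts to wrapped_experts, expert state_dict keys move from the ".moe_layer.experts.experts." prefix to ".moe_layer.experts.wrapped_experts." (hence the model_checkpoint.py hunks). Below is a hedged, hypothetical sketch of a one-off key remap for expert checkpoints saved before this change; remap_legacy_expert_keys is not part of the patch and only reuses the key.replace pattern from try_load_moe_checkpoint.

# Hypothetical migration sketch, not part of the patch: rewrites pre-rename
# expert checkpoint keys to the new wrapped_experts prefix.

OLD_PREFIX = ".moe_layer.experts.experts."
NEW_PREFIX = ".moe_layer.experts.wrapped_experts."


def remap_legacy_expert_keys(expert_state_dict: dict) -> dict:
    # Keys without the old prefix pass through unchanged.
    return {key.replace(OLD_PREFIX, NEW_PREFIX): value for key, value in expert_state_dict.items()}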