diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 14fd451..211c610 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -12,6 +12,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
 
+from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.utils.logger import get_logger
 
@@ -391,7 +392,9 @@ class MOELayer(Base):
         self.overlap_degree = gpc.config.model.moe_overlap_degree
 
         # TODO tutel does not reshape inputs for each expert, so its logic will be different with current experts.py
-        assert (not self.use_tutel) or self.num_local_experts == 1, "only support num_local_experts=1 when enable tutel"
+        assert (not self.use_tutel) or (
+            self.ep_size == gpc.get_world_size(ParallelMode.DATA) and self.num_local_experts == 1
+        ), "tutel only supports expert parallel size equal to data parallel size and num_local_experts == 1"
 
     def forward(self, *inputs: Tensor) -> Tensor:
         # Implement Algorithm 2 from GShard paper.
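
For context, here is a minimal sketch of the precondition the tightened assertion enforces. The helper name `check_tutel_preconditions` is hypothetical and not part of the patch; it only mirrors the check added to `MOELayer.__init__`, assuming (as the TODO suggests) that the tutel path requires one expert per rank with the expert parallel group spanning the full data parallel group.

```python
# Hypothetical standalone illustration of the new assert; not part of the patch.
from internlm.core.context.parallel_context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc


def check_tutel_preconditions(use_tutel: bool, ep_size: int, num_local_experts: int) -> None:
    """Raise if tutel is enabled under a parallel layout it does not support."""
    if not use_tutel:
        return
    # ep_size must match the data parallel world size, and each rank may host
    # only a single local expert, since tutel does not reshape inputs per expert.
    dp_size = gpc.get_world_size(ParallelMode.DATA)
    if ep_size != dp_size or num_local_experts != 1:
        raise AssertionError(
            f"tutel only supports expert parallel size equal to data parallel size "
            f"(ep_size={ep_size}, dp_size={dp_size}) and num_local_experts == 1 "
            f"(got {num_local_experts})"
        )
```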