From d74ad7cca76fcbb0d6776f6733b5a79cf67789ac Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 17 Nov 2023 18:58:52 +0800
Subject: [PATCH] change assert condition for tutel

---
 internlm/moe/sharded_moe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 14fd451..211c610 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -12,6 +12,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
 
+from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.utils.logger import get_logger
 
@@ -391,7 +392,9 @@ class MOELayer(Base):
         self.overlap_degree = gpc.config.model.moe_overlap_degree
 
         # TODO tutel does not reshape inputs for each expert, so its logic will be different with current experts.py
-        assert (not self.use_tutel) or self.num_local_experts == 1, "only support num_local_experts=1 when enable tutel"
+        assert (not self.use_tutel) or (
+            self.ep_size == gpc.get_world_size(ParallelMode.DATA) and self.num_local_experts == 1
+        ), "tutel only supports expert parallel size equals to data parallel size"
 
     def forward(self, *inputs: Tensor) -> Tensor:
         # Implement Algorithm 2 from GShard paper.
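
For readers skimming the change, the following is a minimal standalone sketch of the constraint the new assertion enforces: with tutel enabled, the expert parallel size must equal the data parallel world size and each rank must hold exactly one local expert. The sketch is not part of the patch; the helper check_tutel_config and its arguments are hypothetical names introduced only for illustration and do not depend on InternLM being installed.

def check_tutel_config(use_tutel: bool, ep_size: int, num_local_experts: int,
                       data_parallel_world_size: int) -> None:
    """Raise if tutel is enabled with an unsupported parallel layout."""
    # Mirrors the patched assertion: when tutel is on, the expert parallel
    # size must match the data parallel world size and each rank must hold
    # exactly one local expert.
    if use_tutel and not (
        ep_size == data_parallel_world_size and num_local_experts == 1
    ):
        raise AssertionError(
            "tutel only supports expert parallel size equal to data parallel "
            "size with num_local_experts=1"
        )

# Accepted: 8 data-parallel ranks, expert parallelism spanning all 8 ranks,
# one expert per rank.
check_tutel_config(use_tutel=True, ep_size=8, num_local_experts=1,
                   data_parallel_world_size=8)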