change assert condition for tutel

pull/506/head
Wenwen Qu 2023-11-17 18:58:52 +08:00 committed by Qu Wenwen
parent d20aa41d86
commit d74ad7cca7
1 changed file with 4 additions and 1 deletion


@@ -12,6 +12,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
+from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.utils.logger import get_logger
@@ -391,7 +392,9 @@ class MOELayer(Base):
         self.overlap_degree = gpc.config.model.moe_overlap_degree
         # TODO tutel does not reshape inputs for each expert, so its logic will be different with current experts.py
-        assert (not self.use_tutel) or self.num_local_experts == 1, "only support num_local_experts=1 when enable tutel"
+        assert (not self.use_tutel) or (
+            self.ep_size == gpc.get_world_size(ParallelMode.DATA) and self.num_local_experts == 1
+        ), "tutel only supports expert parallel size equals to data parallel size"
     def forward(self, *inputs: Tensor) -> Tensor:
         # Implement Algorithm 2 from GShard paper.
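
For context, below is a minimal, self-contained sketch (not part of the commit; the helper tutel_config_is_valid and the concrete sizes are hypothetical) of which configurations the tightened assert now accepts: the old check only required a single local expert when tutel was enabled, while the new check additionally requires the expert parallel size to equal the data parallel world size.

# Hypothetical illustration of the new assert condition; names and values are
# assumptions, not code from the repository.
def tutel_config_is_valid(use_tutel: bool, ep_size: int, num_local_experts: int, data_parallel_size: int) -> bool:
    # Tutel is only allowed when expert parallel size matches data parallel size
    # and each rank holds exactly one local expert.
    return (not use_tutel) or (ep_size == data_parallel_size and num_local_experts == 1)

# Accepted: tutel disabled, or tutel enabled with matching sizes and one local expert.
assert tutel_config_is_valid(use_tutel=False, ep_size=4, num_local_experts=2, data_parallel_size=8)
assert tutel_config_is_valid(use_tutel=True, ep_size=8, num_local_experts=1, data_parallel_size=8)
# Rejected under the new condition: tutel enabled but ep_size != data parallel size.
assert not tutel_config_is_valid(use_tutel=True, ep_size=4, num_local_experts=1, data_parallel_size=8)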