mirror of https://github.com/InternLM/InternLM

change assert condition for tutel
parent: d20aa41d86
commit: d74ad7cca7
@@ -12,6 +12,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module

+from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.utils.logger import get_logger

@@ -391,7 +392,9 @@ class MOELayer(Base):
         self.overlap_degree = gpc.config.model.moe_overlap_degree

         # TODO tutel does not reshape inputs for each expert, so its logic will be different with current experts.py
-        assert (not self.use_tutel) or self.num_local_experts == 1, "only support num_local_experts=1 when enable tutel"
+        assert (not self.use_tutel) or (
+            self.ep_size == gpc.get_world_size(ParallelMode.DATA) and self.num_local_experts == 1
+        ), "tutel only supports expert parallel size equals to data parallel size"

     def forward(self, *inputs: Tensor) -> Tensor:
         # Implement Algorithm 2 from GShard paper.
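
For context, the tightened assertion only permits the tutel path when the expert parallel size matches the data parallel world size and each rank holds exactly one local expert. Below is a minimal standalone sketch of that guard; the check_tutel_config helper and its plain integer arguments are hypothetical stand-ins for the MOELayer attributes and gpc.get_world_size(ParallelMode.DATA), not InternLM APIs.

    # Hypothetical standalone sketch of the guard added in this commit.
    # ep_size / data_world_size / num_local_experts stand in for the
    # MOELayer attributes and gpc.get_world_size(ParallelMode.DATA).
    def check_tutel_config(use_tutel: bool, ep_size: int,
                           data_world_size: int, num_local_experts: int) -> None:
        assert (not use_tutel) or (
            ep_size == data_world_size and num_local_experts == 1
        ), "tutel only supports expert parallel size equals to data parallel size"


    if __name__ == "__main__":
        # Fine: tutel disabled, so the layout is unconstrained.
        check_tutel_config(use_tutel=False, ep_size=4, data_world_size=8, num_local_experts=2)
        # Fine: tutel enabled, expert parallel group spans the data parallel group, one local expert.
        check_tutel_config(use_tutel=True, ep_size=8, data_world_size=8, num_local_experts=1)
        # Would raise AssertionError: tutel enabled but ep_size != data parallel size.
        # check_tutel_config(use_tutel=True, ep_size=4, data_world_size=8, num_local_experts=1)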