diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 340efac..7f3e415 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -455,7 +455,6 @@ class ParallelContext(metaclass=SingletonMeta):
         self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
         self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
         self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
-        self._set_parallel_size_from_config(parallel_config, "expert", "expert_parallel_size")

         # the user should not set the data parallel size manually
         # instead, it should be calculated based on other parallel config
@@ -467,13 +466,15 @@ class ParallelContext(metaclass=SingletonMeta):
         if self.zero1_parallel_size <= 0:
             self.zero1_parallel_size = self.data_parallel_size

-        # if not set expert_parallel_size in parallel config
-        if self.expert_parallel_size <= 0:
-            # by default, expert_parallel_size equals to data_parallel_size, but if the number of experts is smaller
-            # than data_parallel_size, set expert_parallel_size to be the number of experts to make sure each device
-            # has one expert.
-            self.expert_parallel_size = min(self.data_parallel_size, self.config.model.get("num_experts", 1))
-            logger.warning(f"not set expert parallel size, set it as {self.expert_parallel_size}")
+        assert (
+            self.data_parallel_size % self.config.model.get("num_experts", 1) == 0
+            or self.config.model.get("num_experts", 1) % self.data_parallel_size == 0
+        ), "can not place the experts evenly"
+
+        # by default, expert_parallel_size equals to data_parallel_size, but if the number of experts is smaller
+        # than data_parallel_size, set expert_parallel_size to be the number of experts to make sure each device
+        # has one expert.
+        self.expert_parallel_size = min(self.data_parallel_size, self.config.model.get("num_experts", 1))

         self.check_sanity()

diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 7ec010c..660cc55 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -73,9 +73,6 @@ def args_sanity_check():
     if "tensor" not in gpc.config.parallel:
         gpc.config.parallel._add_item("tensor", 1)

-    if "expert" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("expert", -1)
-
     # processing the data config in gpc
     data = gpc.config.data
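For reference, below is a minimal standalone sketch of the placement rule this diff enforces. The helper name `derive_expert_parallel_size` is hypothetical (it does not exist in InternLM); it only mirrors the new assertion and the `min(data_parallel_size, num_experts)` default shown in the hunk above.

```python
# Hypothetical helper mirroring the logic introduced in the diff (not part of InternLM).
def derive_expert_parallel_size(data_parallel_size: int, num_experts: int = 1) -> int:
    # The experts must map evenly onto the data-parallel ranks (or vice versa),
    # otherwise some ranks would hold a different number of experts.
    assert (
        data_parallel_size % num_experts == 0
        or num_experts % data_parallel_size == 0
    ), "can not place the experts evenly"
    # By default expert_parallel_size equals data_parallel_size; with fewer experts
    # than data-parallel ranks it shrinks so that each device still holds one expert.
    return min(data_parallel_size, num_experts)


if __name__ == "__main__":
    print(derive_expert_parallel_size(8, num_experts=4))  # 4 experts over 8 dp ranks -> 4
    print(derive_expert_parallel_size(4, num_experts=8))  # 8 experts over 4 dp ranks -> 4
    # derive_expert_parallel_size(8, num_experts=3) raises AssertionError
```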