mirror of https://github.com/InternLM/InternLM
do not set ep size from config
parent
c4c43bf157
commit
6195ea724f
|
@ -455,7 +455,6 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
|
||||
self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
|
||||
self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
|
||||
self._set_parallel_size_from_config(parallel_config, "expert", "expert_parallel_size")
|
||||
|
||||
# the user should not set the data parallel size manually
|
||||
# instead, it should be calculated based on other parallel config
|
||||
|
@ -467,13 +466,15 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
if self.zero1_parallel_size <= 0:
|
||||
self.zero1_parallel_size = self.data_parallel_size
|
||||
|
||||
# if not set expert_parallel_size in parallel config
|
||||
if self.expert_parallel_size <= 0:
|
||||
# by default, expert_parallel_size equals to data_parallel_size, but if the number of experts is smaller
|
||||
# than data_parallel_size, set expert_parallel_size to be the number of experts to make sure each device
|
||||
# has one expert.
|
||||
self.expert_parallel_size = min(self.data_parallel_size, self.config.model.get("num_experts", 1))
|
||||
logger.warning(f"not set expert parallel size, set it as {self.expert_parallel_size}")
|
||||
assert (
|
||||
self.data_parallel_size % self.config.model.get("num_experts", 1) == 0
|
||||
or self.config.model.get("num_experts", 1) % self.data_parallel_size == 0
|
||||
), "can not place the experts evenly"
|
||||
|
||||
# by default, expert_parallel_size equals to data_parallel_size, but if the number of experts is smaller
|
||||
# than data_parallel_size, set expert_parallel_size to be the number of experts to make sure each device
|
||||
# has one expert.
|
||||
self.expert_parallel_size = min(self.data_parallel_size, self.config.model.get("num_experts", 1))
|
||||
|
||||
self.check_sanity()
|
||||
|
||||
|
|
|
@ -73,9 +73,6 @@ def args_sanity_check():
|
|||
if "tensor" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("tensor", 1)
|
||||
|
||||
if "expert" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("expert", -1)
|
||||
|
||||
# processing the data config in gpc
|
||||
data = gpc.config.data
|
||||
|
||||
|
|
Loading…
Reference in New Issue