create expert data group and broadcast moe parameter in expert data group

pull/375/head
Wenwen Qu 2023-08-21 11:40:39 +08:00
parent 08532dc20b
commit 12c614db94
3 changed files with 20 additions and 18 deletions


@@ -36,7 +36,7 @@ class Config(dict):
config (dict): The dict object to be wrapped.
"""
def __init__(self, config: dict = None):
def __init__(self, config: dict = None): # pylint: disable=W0231
if config is not None:
for k, v in config.items():
self._add_item(k, v)
@@ -100,7 +100,7 @@ class Config(dict):
module_name = filepath.stem
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
module = source_file.load_module() # pylint: disable=W4902,E1120
module = source_file.load_module() # pylint: disable=W4902,E1120,W1505
# load into config
config = Config()
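For context on the added W1505 suppression: `SourceFileLoader.load_module()` has been deprecated since Python 3.4, which is exactly what pylint's W1505 (deprecated-method) flags. A minimal sketch of the loading pattern this hunk touches; the filtering of module attributes into a plain dict is an assumption, not code from this repo:

```python
from importlib.machinery import SourceFileLoader
from pathlib import Path


def load_py_config(path: str) -> dict:
    """Load a Python config file as a module and collect its top-level settings."""
    filepath = Path(path)
    loader = SourceFileLoader(fullname=filepath.stem, path=str(filepath))
    module = loader.load_module()  # deprecated API -> pylint W1505
    # Keep only user-defined, non-underscore attributes (assumed filtering rule).
    return {k: v for k, v in vars(module).items() if not k.startswith("_")}
```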
@@ -438,14 +438,12 @@ class ParallelContext(metaclass=SingletonMeta):
self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
self._set_parallel_size_from_config(parallel_config, "expert", "expert_parallel_size")
# the user should not set the data parallel size manually
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# TODO: data parallel size can differ from the expert parallel size
self.expert_parallel_size = self.data_parallel_size
if self.zero1_parallel_size <= 0:
self.zero1_parallel_size = self.data_parallel_size
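A standalone sketch (the function name is hypothetical) of how the sizes in this hunk relate: the data parallel size is derived rather than user-set, the expert parallel size is tied to it for now per the TODO, and zero1 falls back to the data parallel size when not configured:

```python
def derive_parallel_sizes(world_size: int, pipeline: int, tensor: int, zero1: int = -1) -> dict:
    """Hypothetical helper mirroring the assignments above."""
    data = world_size // (pipeline * tensor)
    expert = data  # tied to the data parallel size for now (see TODO above)
    zero1 = data if zero1 <= 0 else zero1
    return {"data": data, "expert": expert, "zero1": zero1}


# e.g. 8 ranks with pp=1, tp=2 -> dp=4, ep=4, zero1=4
assert derive_parallel_sizes(8, pipeline=1, tensor=2) == {"data": 4, "expert": 4, "zero1": 4}
```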
@@ -470,7 +468,7 @@ class ParallelContext(metaclass=SingletonMeta):
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
if self.config.model.num_experts > 1:
initializers.append(pgroup_initializer.Initializer_Expert(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Expert_Data(*initializer_args))
for initializer in initializers:
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):
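The isinstance check above matters because most initializers return a single (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode) tuple, while Initializer_Expert_Data returns a list with one such tuple per mode (EXPERT and EXPERT_DATA). An illustrative sketch with stub names, not the file's actual registration code:

```python
def register(setting):
    """Stub for whatever the ParallelContext does with each returned tuple."""
    local_rank, group_size, group, cpu_group, ranks, mode = setting
    print(f"{mode}: ranks={ranks}, local_rank={local_rank}")


def handle(parallel_setting):
    # Initializer_Expert_Data returns a list of tuples (one per ParallelMode);
    # the other initializers return a single tuple.
    if isinstance(parallel_setting, list):
        for setting in parallel_setting:
            register(setting)
    else:
        register(parallel_setting)


handle((0, 2, None, None, [0, 1], "EXPERT"))
handle([(0, 2, None, None, [0, 2], "EXPERT"),
        (0, 2, None, None, [0, 4], "EXPERT_DATA")])
```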


@@ -422,8 +422,6 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
super().__init__(*args, **kwargs)
self.num_expert_parallel_group = self.world_size // self.expert_parallel_size
assert self.world_size % self.rank_num_per_expert_group == 0
def _get_expert_parallel_ranks(self):
"""
Create expert and data parallel groups
@@ -434,17 +432,18 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
expert_data_parallel_group = [0,4], [2,6], [1,5], [3,7]
"""
data_parallel_groups = []
for i in range(self.model_parallel_size):
data_parallel_groups.append(list(range(i, self.world_size, self.model_parallel_size)))
model_parallel_size = self.pipeline_parallel_size * self.tensor_parallel_size
for i in range(model_parallel_size):
data_parallel_groups.append(list(range(i, self.world_size, model_parallel_size)))
expert_parallel_groups = []
expert_data_parallel_groups = []
for dp_ranks in range(self.num_expert_parallel_group):
for dp_ranks in data_parallel_groups:
# partition of expert parallel group, e.g. [0,2], [4,6]
part_ep_group = []
for i in range(0, self.data_parallel_size, self.expert_parallel_size):
part_ep_group.append(dp_ranks[i : i + self.expert_parallel_size])
expert_data_parallel_groups.extend(part_ep_group)
expert_parallel_groups.extend(part_ep_group)
for expert_dp_ranks in zip(*part_ep_group):
expert_data_parallel_groups.append(list(expert_dp_ranks))
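The loop above is easiest to follow with concrete numbers. A runnable sketch plugging in the docstring's example (world_size = 8, pipeline * tensor = 2, expert_parallel_size = 2): each data parallel group is partitioned into expert parallel subgroups of size expert_parallel_size, and zipping those subgroups column-wise yields the expert data parallel groups, i.e. the ranks that hold replicas of the same experts:

```python
world_size, model_parallel_size, expert_parallel_size = 8, 2, 2
data_parallel_size = world_size // model_parallel_size

data_parallel_groups = [list(range(i, world_size, model_parallel_size))
                        for i in range(model_parallel_size)]  # [0,2,4,6] and [1,3,5,7]

expert_parallel_groups, expert_data_parallel_groups = [], []
for dp_ranks in data_parallel_groups:
    part_ep_group = [dp_ranks[i:i + expert_parallel_size]
                     for i in range(0, data_parallel_size, expert_parallel_size)]
    expert_parallel_groups.extend(part_ep_group)
    expert_data_parallel_groups.extend(list(g) for g in zip(*part_ep_group))

print(expert_parallel_groups)       # [[0, 2], [4, 6], [1, 3], [5, 7]]
print(expert_data_parallel_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7]]
```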
@@ -458,6 +457,11 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
list: [(local_rank, group_world_size, process_group, ranks_in_group, mode), ...]:
A list of length 2 consisting of the expert parallel and the expert data parallel information tuples.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
expert_parallel_groups, expert_data_parallel_groups = self._get_expert_parallel_ranks()
groups = []
@@ -473,7 +477,7 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
groups.append((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EXPERT))
groups.append((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EXPERT))
for ranks in expert_data_parallel_groups:
group = dist.new_group(ranks)
@@ -487,8 +491,8 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
groups.append(
(local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EXPERT_DATA)
)
groups.append(
(local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EXPERT_DATA)
)
return groups
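The creation loops above follow the usual torch.distributed pattern: every rank must enter dist.new_group for every ranks list (the call is collective), but only the group containing the current rank is kept. A hedged sketch of one way to write that pattern; it assumes an already initialized default process group and is not the file's exact code:

```python
import torch.distributed as dist


def build_groups(rank, rank_lists, mode):
    """Create one process group (plus a gloo CPU group) per ranks list."""
    results = []
    for ranks in rank_lists:
        group = dist.new_group(ranks)                      # default (e.g. NCCL) backend
        cpu_group = dist.new_group(ranks, backend="gloo")  # CPU-side collectives
        if rank in ranks:
            results.append((ranks.index(rank), len(ranks), group, cpu_group, ranks, mode))
    return results
```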


@@ -58,7 +58,7 @@ from internlm.utils.model_checkpoint import (
from internlm.utils.parallel import (
get_parallel_log_file_name,
is_no_pp_or_last_stage,
sync_model_param,
sync_model_param_with_ep,
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
@@ -149,7 +149,7 @@ def initialize_model():
# This sync is very important, because the model weights kept in the optimizer are copied
# from the original parameters in memory, so we must make sure the dp sync
# does not cause the model weights in the optimizer to differ from the original parameters.
sync_model_param(model, parallel_mode=ParallelMode.DATA)
sync_model_param_with_ep(model)
# This function is needed to make sure parameters that are not split by tensor parallelism are
# the same across all tensor parallel ranks.
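The body of sync_model_param_with_ep (the new call above) is not part of this diff. A hypothetical sketch of what the commit title implies it does: MoE expert parameters are only replicated inside their expert data parallel group, so they are broadcast over ParallelMode.EXPERT_DATA, while every other parameter keeps the plain data parallel broadcast. The is_expert tagging and the gpc accessors used here are assumptions:

```python
import torch.distributed as dist

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc


def sync_model_param_with_ep(model):
    """Hypothetical sketch: broadcast each parameter within the group that replicates it."""
    for param in model.parameters():
        # How MoE params are tagged is assumed here (e.g. an `is_expert` attribute).
        mode = ParallelMode.EXPERT_DATA if getattr(param, "is_expert", False) else ParallelMode.DATA
        ranks = gpc.get_ranks_in_group(mode)
        dist.broadcast(param, src=ranks[0], group=gpc.get_group(mode))
```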