From 754f1d961a9b9d4b172c9a2e27386fe69a810d81 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 17 Aug 2023 15:31:27 +0800
Subject: [PATCH] fix moe bugs in zero optimizer

---
 .../solver/optimizer/hybrid_zero_optim.py | 64 ++++++++++++-------
 internlm/utils/parallel.py                |  2 +-
 2 files changed, 41 insertions(+), 25 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d59316c..7c05843 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -166,21 +166,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
-            if "moe" in param_group.keys() and param_group["moe"]:
-                print("true", flush=True)
-                continue
-
             group_params = param_group["params"]
 
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params
 
             # assign parameters to ranks the params in the list are sorted
-            params_per_rank, no_params_ranks = self._partition_param_list(group_params)
+            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 
-            # store the mapping between param to rank each param should belong to only one rank
+            # store the mapping between param and rank; each param should belong to only one rank.
+            # we could skip the moe params and not keep them in _param_store to save memory
+            # (at the cost of handling moe params through a separate code path), but that would
+            # increase complexity and reduce code readability.
             for rank, params in enumerate(params_per_rank):
                 # check whether any rank is not assigned params.
                 if len(params) != 0:
@@ -261,32 +260,47 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
 
-    def _partition_param_list(self, param_list):
+    def _get_real_dp_process_group(self, param_groups):
+        if "moe" in param_groups.keys() and param_groups["moe"]:
+            return ParallelMode.EXPERT_DATA
+        else:
+            return ParallelMode.DATA
+
+    def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
+        param_list = param_group["params"]
 
-        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-        for i, param in enumerate(sorted_params):
-            global_id = str(i)
-            for j in range(len(param.size())):
-                global_id = "_".join([global_id, str(param.size()[j])])
+        if "moe" in param_group.keys() and param_group["moe"]:
+            # just add current params to params_per_rank[_zero_local_rank]
+            params_per_rank[self._zero_local_rank] = list(param_list)
+            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
+            no_params_ranks = list(range(self._zero_world_size))
+            no_params_ranks.pop(self._zero_local_rank)  # only the local rank holds the moe params
 
-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
-            params_per_rank[rank_to_go].append(param)
-            self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
-            numel_per_rank[rank_to_go] += param.numel()
+        else:
+            sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
+            for i, param in enumerate(sorted_params):
+                global_id = str(i)
+                for j in range(len(param.size())):
+                    global_id = "_".join([global_id, str(param.size()[j])])
 
-        # check whether any rank is not assigned to parameters.
-        for rank, params in enumerate(params_per_rank):
-            if len(params) == 0:
-                no_params_ranks.append(rank)
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
+                params_per_rank[rank_to_go].append(param)
+                self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
+                numel_per_rank[rank_to_go] += param.numel()
 
-        if gpc.is_rank_for_log():
-            logger.info(  # pylint: disable=W1203
-                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
-            )
+            # check whether any rank is not assigned to parameters.
+            for rank, params in enumerate(params_per_rank):
+                if len(params) == 0:
+                    no_params_ranks.append(rank)
+
+            if gpc.is_rank_for_log():
+                logger.info(  # pylint: disable=W1203
+                    f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+                )
 
         return params_per_rank, set(no_params_ranks)
 
@@ -296,6 +310,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
+                # we should not reduce the moe params here
                 if param.requires_grad and not is_moe_param(param):
                     reduce_rank = None
 
@@ -496,6 +511,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if not self._overlap_communication:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
+                    # we should not reduce the moe params here
                    if param.grad is not None and not is_moe_param(param):
                         self._store_and_try_reduce_grads_by_bucket(param)
 
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 5a9e4c6..5df51d1 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -43,7 +43,7 @@ def sync_tensor(tensor, parallel_mode):
 
 
 # TODO: will be used in expert data parallel, may can also used in sync_model_param_within_tp
-def sync_model_param_within_ep(model):
+def sync_model_param_with_ep(model):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
 
     Args:
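
A minimal standalone sketch of the partitioning rule that the reworked _partition_param_list above implements, assuming a dict-style param_group with a "moe" flag: moe/expert param groups are kept whole on the local ZeRO rank (they are synchronized over the expert data-parallel group rather than sharded), while ordinary groups are spread across ranks by greedily giving the largest remaining tensor to the least-loaded rank. The function name partition_param_group and its arguments are illustrative only and not part of the InternLM API.

from typing import Dict, List, Set, Tuple

import torch


def partition_param_group(
    param_group: Dict, zero_world_size: int, zero_local_rank: int
) -> Tuple[List[List[torch.Tensor]], Set[int]]:
    # one bucket of params per ZeRO rank
    params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(zero_world_size)]
    param_list = param_group["params"]

    if param_group.get("moe", False):
        # expert params are not sharded: the local rank keeps them all,
        # so every other rank is reported as having no params for this group
        params_per_rank[zero_local_rank] = list(param_list)
        no_params_ranks = [r for r in range(zero_world_size) if r != zero_local_rank]
    else:
        # greedy balance: hand the next-largest tensor to the lightest rank
        numel_per_rank = [0] * zero_world_size
        for param in sorted(param_list, key=lambda p: p.numel(), reverse=True):
            rank_to_go = numel_per_rank.index(min(numel_per_rank))
            params_per_rank[rank_to_go].append(param)
            numel_per_rank[rank_to_go] += param.numel()
        no_params_ranks = [r for r, ps in enumerate(params_per_rank) if not ps]

    return params_per_rank, set(no_params_ranks)


if __name__ == "__main__":
    group = {"params": [torch.zeros(n) for n in (4, 3, 2, 1)], "moe": False}
    per_rank, empty = partition_param_group(group, zero_world_size=2, zero_local_rank=0)
    print([sum(p.numel() for p in ps) for ps in per_rank], empty)  # [5, 5] set()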