From 48ed81d2e6ebdeb4af0f49bfb3ba69051fb92f99 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 17 Aug 2023 15:44:36 +0800
Subject: [PATCH] fix more moe bugs in zero optimizer

---
Reviewer note (text in this section is not part of the commit message):
the second hunk additionally replaces
    no_params_ranks.pop(self._zero_world_size)
with
    no_params_ranks.pop(self._zero_local_rank)
The list is built as list(range(self._zero_world_size)), so popping index
_zero_world_size is always out of range and raises IndexError for every MoE
param group. The local rank is the one that was just given the params, so it
is the entry that has to be removed from the "no params" list. The diffstat
below was updated for this extra change; the commit id in the first line and
the blob ids in the "index" line still refer to the original commit.

 internlm/solver/optimizer/hybrid_zero_optim.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 526d28d..7c05843 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -172,7 +172,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fp16_param_groups[group_id] = group_params
 
             # assign parameters to ranks the params in the list are sorted
-            params_per_rank, no_params_ranks = self._partition_param_list(group_params)
+            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 
@@ -260,15 +260,22 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
 
-    def _partition_param_list(self, param_list):
+    def _get_real_dp_process_group(self, param_groups):
+        if "moe" in param_groups.keys() and param_groups["moe"]:
+            return ParallelMode.EXPERT_DATA
+        else:
+            return ParallelMode.DATA
+
+    def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
+        param_list = param_group["params"]
 
-        if "moe" in param_list.keys() and param_list["moe"]:
+        if "moe" in param_group.keys() and param_group["moe"]:
             # just add current params to params_per_rank[_zero_local_rank]
-            params_per_rank[self._zero_local_rank] = list(param_list["params"])
+            params_per_rank[self._zero_local_rank] = list(param_list)
             self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
             no_params_ranks = list(range(self._zero_world_size))
-            no_params_ranks.pop(self._zero_world_size)
+            no_params_ranks.pop(self._zero_local_rank)