From 9ee57e6c8abe3c1aebb4caf4043bd45d93d29ebd Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 17 Aug 2023 15:31:27 +0800
Subject: [PATCH 1/2] fix moe bugs in zero optimizer

---
 .../solver/optimizer/hybrid_zero_optim.py | 53 +++++++++++--------
 1 file changed, 31 insertions(+), 22 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d59316c..526d28d 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -166,10 +166,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
-            if "moe" in param_group.keys() and param_group["moe"]:
-                print("true", flush=True)
-                continue
-
             group_params = param_group["params"]
 
             # add the fp16 params to fp16_param_groups for bookkeeping
@@ -180,7 +176,10 @@ class HybridZeroOptimizer(BaseOptimizer):
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 
-            # store the mapping between param to rank each param should belong to only one rank
+            # store the mapping between param to rank; each param should belong to only one rank.
+            # we could skip the moe params and not keep them in _param_store to save memory
+            # (which means moe params would need to be handled differently), but that would
+            # increase complexity and reduce code readability.
             for rank, params in enumerate(params_per_rank):
                 # check whether any rank is not assigned params.
                 if len(params) != 0:
@@ -267,26 +266,34 @@ class HybridZeroOptimizer(BaseOptimizer):
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
 
-        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-        for i, param in enumerate(sorted_params):
-            global_id = str(i)
-            for j in range(len(param.size())):
-                global_id = "_".join([global_id, str(param.size()[j])])
+        if "moe" in param_list.keys() and param_list["moe"]:
+            # just add current params to params_per_rank[_zero_local_rank]
+            params_per_rank[self._zero_local_rank] = list(param_list["params"])
+            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
+            no_params_ranks = list(range(self._zero_world_size))
+            no_params_ranks.pop(self._zero_local_rank)
 
-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
-            params_per_rank[rank_to_go].append(param)
-            self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
-            numel_per_rank[rank_to_go] += param.numel()
+        else:
+            sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
+            for i, param in enumerate(sorted_params):
+                global_id = str(i)
+                for j in range(len(param.size())):
+                    global_id = "_".join([global_id, str(param.size()[j])])
 
-        # check whether any rank is not assigned to parameters.
-        for rank, params in enumerate(params_per_rank):
-            if len(params) == 0:
-                no_params_ranks.append(rank)
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
+                params_per_rank[rank_to_go].append(param)
+                self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
+                numel_per_rank[rank_to_go] += param.numel()
 
-        if gpc.is_rank_for_log():
-            logger.info(  # pylint: disable=W1203
-                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
-            )
+            # check whether any rank is not assigned to parameters.
+            for rank, params in enumerate(params_per_rank):
+                if len(params) == 0:
+                    no_params_ranks.append(rank)
+
+            if gpc.is_rank_for_log():
+                logger.info(  # pylint: disable=W1203
+                    f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+                )
 
         return params_per_rank, set(no_params_ranks)
 
@@ -296,6 +303,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
+                # we should not reduce moe params here
                 if param.requires_grad and not is_moe_param(param):
                     reduce_rank = None
 
@@ -496,6 +504,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if not self._overlap_communication:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
+                    # we should not reduce moe params here
                     if param.grad is not None and not is_moe_param(param):
                         self._store_and_try_reduce_grads_by_bucket(param)
 

From 48ed81d2e6ebdeb4af0f49bfb3ba69051fb92f99 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 17 Aug 2023 15:44:36 +0800
Subject: [PATCH 2/2] fix more moe bugs in zero optimizer

---
 internlm/solver/optimizer/hybrid_zero_optim.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 526d28d..7c05843 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -172,7 +172,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fp16_param_groups[group_id] = group_params
 
             # assign parameters to ranks the params in the list are sorted
-            params_per_rank, no_params_ranks = self._partition_param_list(group_params)
+            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 
@@ -260,15 +260,22 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
 
-    def _partition_param_list(self, param_list):
+    def _get_real_dp_process_group(self, param_groups):
+        if "moe" in param_groups.keys() and param_groups["moe"]:
+            return ParallelMode.EXPERT_DATA
+        else:
+            return ParallelMode.DATA
+
+    def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
+        param_list = param_group["params"]
 
-        if "moe" in param_list.keys() and param_list["moe"]:
+        if "moe" in param_group.keys() and param_group["moe"]:
             # just add current params to params_per_rank[_zero_local_rank]
-            params_per_rank[self._zero_local_rank] = list(param_list["params"])
+            params_per_rank[self._zero_local_rank] = list(param_list)
             self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
             no_params_ranks = list(range(self._zero_world_size))
             no_params_ranks.pop(self._zero_local_rank)
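
Editor's note: the sketch below is a minimal, standalone illustration of the partitioning behaviour the two patches above converge on; it is not code from the repository. The names FakeParam, partition_param_group, zero_world_size, and zero_local_rank are hypothetical stand-ins for the real torch parameters and HybridZeroOptimizer attributes, while the {"moe": True, "params": [...]} group layout mirrors the diff. A group flagged as moe is kept entirely on the local ZeRO rank and every other rank is marked as having no params for that group; ordinary groups are split greedily by element count.

from typing import Dict, List, Set, Tuple


class FakeParam:
    """Stand-in for torch.nn.Parameter; only numel() matters for partitioning."""

    def __init__(self, n: int) -> None:
        self._n = n

    def numel(self) -> int:
        return self._n

    def __repr__(self) -> str:
        return f"P({self._n})"


def partition_param_group(
    param_group: Dict, zero_world_size: int, zero_local_rank: int
) -> Tuple[List[List[FakeParam]], Set[int]]:
    """Split one optimizer param group across the ZeRO data-parallel ranks."""
    params_per_rank: List[List[FakeParam]] = [[] for _ in range(zero_world_size)]
    numel_per_rank = [0 for _ in range(zero_world_size)]
    param_list = param_group["params"]

    if "moe" in param_group and param_group["moe"]:
        # expert (moe) params are not sharded: keep them all on the local rank
        # and report every other rank as having no params for this group.
        params_per_rank[zero_local_rank] = list(param_list)
        no_params_ranks = [r for r in range(zero_world_size) if r != zero_local_rank]
    else:
        # greedy partition: give each param (largest first) to the rank that
        # currently holds the fewest elements.
        for param in sorted(param_list, key=lambda p: p.numel(), reverse=True):
            rank_to_go = numel_per_rank.index(min(numel_per_rank))
            params_per_rank[rank_to_go].append(param)
            numel_per_rank[rank_to_go] += param.numel()
        no_params_ranks = [r for r, ps in enumerate(params_per_rank) if not ps]

    return params_per_rank, set(no_params_ranks)


if __name__ == "__main__":
    dense_group = {"params": [FakeParam(8), FakeParam(4), FakeParam(2)]}
    moe_group = {"moe": True, "params": [FakeParam(16), FakeParam(16)]}
    print(partition_param_group(dense_group, zero_world_size=2, zero_local_rank=0))
    print(partition_param_group(moe_group, zero_world_size=2, zero_local_rank=0))

With two ranks, the dense group is split across both ranks while the moe group stays on rank 0 and rank 1 ends up in no_params_ranks. This is also why the no_params_ranks.pop(...) in the patch has to remove the local rank rather than index zero_world_size, which would be out of range for the list.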