diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d59316c..526d28d 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -166,10 +166,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
-            if "moe" in param_group.keys() and param_group["moe"]:
-                print("true", flush=True)
-                continue
-
             group_params = param_group["params"]
 
             # add the fp16 params to fp16_param_groups for bookkeeping
@@ -180,7 +176,10 @@ class HybridZeroOptimizer(BaseOptimizer):
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 
-            # store the mapping between param to rank each param should belong to only one rank
+            # store the mapping between param and rank. each param should belong to only one rank.
+            # we could skip moe params and not keep them in _param_store to save memory (which
+            # would mean handling moe params in a different way), but that would increase
+            # complexity and reduce code readability.
             for rank, params in enumerate(params_per_rank):
                 # check whether any rank is not assigned params.
                 if len(params) != 0:
@@ -267,26 +266,34 @@ class HybridZeroOptimizer(BaseOptimizer):
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
 
-        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-        for i, param in enumerate(sorted_params):
-            global_id = str(i)
-            for j in range(len(param.size())):
-                global_id = "_".join([global_id, str(param.size()[j])])
+        if "moe" in param_list.keys() and param_list["moe"]:
+            # just add current params to params_per_rank[_zero_local_rank]
+            params_per_rank[self._zero_local_rank] = list(param_list["params"])
+            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
+            no_params_ranks = list(range(self._zero_world_size))
+            no_params_ranks.pop(self._zero_local_rank)
 
-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
-            params_per_rank[rank_to_go].append(param)
-            self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
-            numel_per_rank[rank_to_go] += param.numel()
+        else:
+            sorted_params = sorted(param_list["params"], key=lambda x: x.numel(), reverse=True)
+            for i, param in enumerate(sorted_params):
+                global_id = str(i)
+                for j in range(len(param.size())):
+                    global_id = "_".join([global_id, str(param.size()[j])])
 
-        # check whether any rank is not assigned to parameters.
-        for rank, params in enumerate(params_per_rank):
-            if len(params) == 0:
-                no_params_ranks.append(rank)
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
+                params_per_rank[rank_to_go].append(param)
+                self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
+                numel_per_rank[rank_to_go] += param.numel()
 
-        if gpc.is_rank_for_log():
-            logger.info(  # pylint: disable=W1203
-                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
-            )
+            # check whether any rank is not assigned to parameters.
+            for rank, params in enumerate(params_per_rank):
+                if len(params) == 0:
+                    no_params_ranks.append(rank)
+
+            if gpc.is_rank_for_log():
+                logger.info(  # pylint: disable=W1203
+                    f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+                )
 
         return params_per_rank, set(no_params_ranks)
 
@@ -296,6 +303,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
+                # the grads of moe params should not be reduced here
                 if param.requires_grad and not is_moe_param(param):
                     reduce_rank = None
 
@@ -496,6 +504,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if not self._overlap_communication:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
+                    # the grads of moe params should not be reduced here
                    if param.grad is not None and not is_moe_param(param):
                         self._store_and_try_reduce_grads_by_bucket(param)
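
For review context, here is a minimal standalone sketch of the partitioning strategy the `_partition_param_list` hunk implements: dense groups are sharded greedily by element count (largest param first, always onto the currently lightest rank), while a moe group is pinned entirely to the local zero rank, since each rank keeps its own experts. `partition_param_group`, `world_size`, and `local_rank` below are illustrative stand-ins for the optimizer's internals (`self._zero_world_size`, `self._zero_local_rank`), not the actual API:

```python
from typing import Dict, List, Set, Tuple

import torch


def partition_param_group(
    param_group: Dict, world_size: int, local_rank: int
) -> Tuple[List[List[torch.Tensor]], Set[int]]:
    params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
    numel_per_rank = [0] * world_size

    if param_group.get("moe", False):
        # moe/expert params are not sharded across zero ranks: only the
        # local rank owns them; every other rank has no params in this group.
        params_per_rank[local_rank] = list(param_group["params"])
        no_params_ranks = [r for r in range(world_size) if r != local_rank]
    else:
        # greedy balancing: hand each param (largest first) to whichever
        # rank currently holds the fewest elements.
        for param in sorted(param_group["params"], key=lambda p: p.numel(), reverse=True):
            rank_to_go = numel_per_rank.index(min(numel_per_rank))
            params_per_rank[rank_to_go].append(param)
            numel_per_rank[rank_to_go] += param.numel()
        no_params_ranks = [r for r, ps in enumerate(params_per_rank) if not ps]

    return params_per_rank, set(no_params_ranks)


# two zero ranks, params of 8/5/3/2 elements: the greedy pass puts
# 8 + 2 = 10 elements on rank 0 and 5 + 3 = 8 on rank 1.
group = {"params": [torch.zeros(n) for n in (5, 3, 8, 2)], "moe": False}
shards, empty_ranks = partition_param_group(group, world_size=2, local_rank=0)
```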
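The `is_moe_param` guards in the last two hunks keep expert gradients out of the bucketed data-parallel reduction: each rank holds different experts, so averaging their grads over the full zero group would mix unrelated weights. A plausible shape for that filter, assuming expert params are tagged with an attribute when the moe layers are built (the `is_expert` name and the `dense_params_to_reduce` helper are assumptions for illustration, not the repo's actual code):

```python
import torch


def is_moe_param(param: torch.Tensor) -> bool:
    # assumed tagging scheme: expert params carry `is_expert = True`,
    # set when the moe layers are constructed.
    return getattr(param, "is_expert", False)


def dense_params_to_reduce(fp16_param_groups):
    # mirrors the filter in the diff: only dense params with grads go
    # through the normal bucketed reduction; expert grads are handled
    # elsewhere.
    for group in fp16_param_groups:
        for param in group:
            if param.grad is not None and not is_moe_param(param):
                yield param
```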