mirror of https://github.com/InternLM/InternLM
fix more moe bugs in zero optimizer
parent 9ee57e6c8a
commit 48ed81d2e6
@@ -172,7 +172,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._fp16_param_groups[group_id] = group_params
 
             # assign parameters to ranks the params in the list are sorted
-            params_per_rank, no_params_ranks = self._partition_param_list(group_params)
+            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 
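Note on the first hunk: the partitioner previously received only the bare parameter list (`group_params`), so the group's `"moe"` flag was invisible to it and expert parameters were ZeRO-sharded like ordinary ones. Passing the whole `param_group` dict lets `_partition_param_list` branch on that flag. A minimal sketch of the dict shape involved (the tensor and hyperparameter values are illustrative, not from the commit):

import torch

expert_weight = torch.nn.Parameter(torch.empty(8, 8))  # placeholder expert tensor

param_group = {
    "params": [expert_weight],  # what the old code passed in as group_params
    "moe": True,                # the flag _partition_param_list now needs to see
    "lr": 1e-4,
}

# Old call site: self._partition_param_list(param_group["params"])
#   -> the "moe" flag is lost before partitioning.
# New call site: self._partition_param_list(param_group)
#   -> the partitioner can read param_group["moe"] and skip sharding.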
@@ -260,15 +260,22 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
 
-    def _partition_param_list(self, param_list):
+    def _get_real_dp_process_group(self, param_groups):
+        if "moe" in param_groups.keys() and param_groups["moe"]:
+            return ParallelMode.EXPERT_DATA
+        else:
+            return ParallelMode.DATA
+
+    def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
+        param_list = param_group["params"]
 
-        if "moe" in param_list.keys() and param_list["moe"]:
+        if "moe" in param_group.keys() and param_group["moe"]:
             # just add current params to params_per_rank[_zero_local_rank]
-            params_per_rank[self._zero_local_rank] = list(param_list["params"])
+            params_per_rank[self._zero_local_rank] = list(param_list)
             self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
             no_params_ranks = list(range(self._zero_world_size))
             no_params_ranks.pop(self._zero_local_rank)
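Note on `_get_real_dp_process_group`: it selects which process group a parameter group's gradients are reduced over, returning `ParallelMode.EXPERT_DATA` for MoE groups and the full `ParallelMode.DATA` group otherwise. A standalone sketch of that selection logic, using a stand-in enum since the commit's `ParallelMode` import is not shown in the diff:

from enum import Enum

class ParallelMode(Enum):
    # Stand-in for InternLM's ParallelMode; member names match the diff.
    DATA = "data"
    EXPERT_DATA = "expert_data"

def get_real_dp_process_group(param_group: dict) -> ParallelMode:
    # Expert (MoE) parameter groups reduce gradients only across the
    # expert-data-parallel ranks; everything else uses plain data parallel.
    if param_group.get("moe", False):
        return ParallelMode.EXPERT_DATA
    return ParallelMode.DATA

assert get_real_dp_process_group({"moe": True}) is ParallelMode.EXPERT_DATA
assert get_real_dp_process_group({"lr": 1e-4}) is ParallelMode.DATA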
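Note on the MoE branch of `_partition_param_list`: expert parameter groups skip ZeRO's usual numel-balanced split across ranks. The local rank keeps all of the group's parameters, and every other rank is recorded in `no_params_ranks`. A dependency-free sketch of that branch (the function name, rank values, and placeholder parameters are made up for illustration):

def partition_moe_group(param_list, zero_local_rank, zero_world_size):
    # Mirror of the diff's MoE branch: no cross-rank split; the local
    # rank owns all of this group's (expert) parameters.
    params_per_rank = [[] for _ in range(zero_world_size)]
    params_per_rank[zero_local_rank] = list(param_list)

    # Every rank except the local one holds no parameters for this group.
    no_params_ranks = list(range(zero_world_size))
    no_params_ranks.pop(zero_local_rank)
    return params_per_rank, no_params_ranks

# e.g. rank 1 of 4 holding two placeholder expert params
per_rank, empty_ranks = partition_moe_group(["w_expert0", "w_expert1"], 1, 4)
assert per_rank[1] == ["w_expert0", "w_expert1"]
assert empty_ranks == [0, 2, 3]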