mirror of https://github.com/InternLM/InternLM
fix more moe bugs in zero optimizer
parent
9ee57e6c8a
commit
48ed81d2e6
|
@ -172,7 +172,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
self._fp16_param_groups[group_id] = group_params
|
self._fp16_param_groups[group_id] = group_params
|
||||||
|
|
||||||
# assign parameters to ranks the params in the list are sorted
|
# assign parameters to ranks the params in the list are sorted
|
||||||
params_per_rank, no_params_ranks = self._partition_param_list(group_params)
|
params_per_rank, no_params_ranks = self._partition_param_list(param_group)
|
||||||
self.param_group_no_params_ranks.append(no_params_ranks)
|
self.param_group_no_params_ranks.append(no_params_ranks)
|
||||||
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
|
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
|
||||||
|
|
||||||
|
@ -260,15 +260,22 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
def num_param_groups(self):
|
def num_param_groups(self):
|
||||||
return len(self._fp16_param_groups)
|
return len(self._fp16_param_groups)
|
||||||
|
|
||||||
def _partition_param_list(self, param_list):
|
def _get_real_dp_process_group(self, param_groups):
|
||||||
|
if "moe" in param_groups.keys() and param_groups["moe"]:
|
||||||
|
return ParallelMode.EXPERT_DATA
|
||||||
|
else:
|
||||||
|
return ParallelMode.DATA
|
||||||
|
|
||||||
|
def _partition_param_list(self, param_group):
|
||||||
no_params_ranks = []
|
no_params_ranks = []
|
||||||
params_per_rank = [[] for _ in range(self._zero_world_size)]
|
params_per_rank = [[] for _ in range(self._zero_world_size)]
|
||||||
numel_per_rank = [0 for _ in range(self._zero_world_size)]
|
numel_per_rank = [0 for _ in range(self._zero_world_size)]
|
||||||
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
|
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
|
||||||
|
param_list = param_group["params"]
|
||||||
|
|
||||||
if "moe" in param_list.keys() and param_list["moe"]:
|
if "moe" in param_group.keys() and param_group["moe"]:
|
||||||
# just add current params to params_per_rank[_zero_local_rank]
|
# just add current params to params_per_rank[_zero_local_rank]
|
||||||
params_per_rank[self._zero_local_rank] = list(param_list["params"])
|
params_per_rank[self._zero_local_rank] = list(param_list)
|
||||||
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
|
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
|
||||||
no_params_ranks = list(range(self._zero_world_size))
|
no_params_ranks = list(range(self._zero_world_size))
|
||||||
no_params_ranks.pop(self._zero_world_size)
|
no_params_ranks.pop(self._zero_world_size)
|
||||||
|
|
Loading…
Reference in New Issue