Fix more MoE bugs in the ZeRO optimizer

pull/375/head
Wenwen Qu 2023-08-17 15:44:36 +08:00
parent 9ee57e6c8a
commit 48ed81d2e6
1 changed file with 11 additions and 4 deletions

View File

@ -172,7 +172,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._fp16_param_groups[group_id] = group_params
# assign parameters to ranks; the params in the list are sorted
params_per_rank, no_params_ranks = self._partition_param_list(group_params)
params_per_rank, no_params_ranks = self._partition_param_list(param_group)
self.param_group_no_params_ranks.append(no_params_ranks)
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
@ -260,15 +260,22 @@ class HybridZeroOptimizer(BaseOptimizer):
def num_param_groups(self):
    """Return how many parameter groups are stored in ``_fp16_param_groups``."""
    fp16_groups = self._fp16_param_groups
    group_count = len(fp16_groups)
    return group_count
def _partition_param_list(self, param_list):
def _get_real_dp_process_group(self, param_groups):
    """Select the parallel mode to use for a parameter group.

    Args:
        param_groups: mapping describing one parameter group; only its
            optional ``"moe"`` entry is inspected here.

    Returns:
        ``ParallelMode.EXPERT_DATA`` when ``param_groups`` has a truthy
        ``"moe"`` entry, otherwise ``ParallelMode.DATA``.
    """
    # A missing "moe" key and a falsy "moe" value are treated the same way.
    is_moe_group = "moe" in param_groups and bool(param_groups["moe"])
    return ParallelMode.EXPERT_DATA if is_moe_group else ParallelMode.DATA
def _partition_param_list(self, param_group):
no_params_ranks = []
params_per_rank = [[] for _ in range(self._zero_world_size)]
numel_per_rank = [0 for _ in range(self._zero_world_size)]
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
param_list = param_group["params"]
if "moe" in param_list.keys() and param_list["moe"]:
if "moe" in param_group.keys() and param_group["moe"]:
# just add current params to params_per_rank[_zero_local_rank]
params_per_rank[self._zero_local_rank] = list(param_list["params"])
params_per_rank[self._zero_local_rank] = list(param_list)
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
no_params_ranks = list(range(self._zero_world_size))
no_params_ranks.pop(self._zero_world_size)