fix moe bugs in zero optimizer

pull/375/head
Wenwen Qu 2023-08-17 15:31:27 +08:00
parent 3bfaad895a
commit 754f1d961a
2 changed files with 41 additions and 25 deletions

View File

@@ -166,21 +166,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
-            if "moe" in param_group.keys() and param_group["moe"]:
-                print("true", flush=True)
-                continue
             group_params = param_group["params"]
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params
             # assign parameters to ranks the params in the list are sorted
-            params_per_rank, no_params_ranks = self._partition_param_list(group_params)
+            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
-            # store the mapping between param to rank each param should belong to only one rank
+            # store the mapping between param to rank each param should belong to only one rank.
+            # we can skip the moe param and do not keep them in _param_store to save memory
+            # (means we need to deal with moe param in a different way), but it will increase
+            # complexity and reduce code readablity.
             for rank, params in enumerate(params_per_rank):
                 # check whether any rank is not assigned params.
                 if len(params) != 0:
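Context for the change above: whether a parameter group is an expert ("moe") group is recorded as a boolean "moe" key in the group dict, which is why _partition_param_list now receives the whole param_group instead of only its "params" list. Below is a minimal, illustrative sketch of how such groups could be built; the create_moe_param_groups helper and the is_expert attribute are assumptions for the example, not code from this repository.

import torch

def create_moe_param_groups(model: torch.nn.Module, lr: float = 1e-4):
    """Illustrative only: split parameters into a dense group and an expert ("moe") group."""
    dense_params, expert_params = [], []
    for param in model.parameters():
        # `is_expert` is an assumed tag set when the MoE layers are constructed
        if getattr(param, "is_expert", False):
            expert_params.append(param)
        else:
            dense_params.append(param)
    return [
        {"params": dense_params, "lr": lr},                # partitioned across ZeRO ranks
        {"params": expert_params, "lr": lr, "moe": True},  # kept whole on each expert-data rank
    ]

# optimizer = torch.optim.AdamW(create_moe_param_groups(model))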
@@ -261,32 +260,47 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
-    def _partition_param_list(self, param_list):
+    def _get_real_dp_process_group(self, param_groups):
+        if "moe" in param_groups.keys() and param_groups["moe"]:
+            return ParallelMode.EXPERT_DATA
+        else:
+            return ParallelMode.DATA
+
+    def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
+        param_list = param_group["params"]
-        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-        for i, param in enumerate(sorted_params):
-            global_id = str(i)
-            for j in range(len(param.size())):
-                global_id = "_".join([global_id, str(param.size()[j])])
-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
-            params_per_rank[rank_to_go].append(param)
-            self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
-            numel_per_rank[rank_to_go] += param.numel()
-        # check whether any rank is not assigned to parameters.
-        for rank, params in enumerate(params_per_rank):
-            if len(params) == 0:
-                no_params_ranks.append(rank)
-        if gpc.is_rank_for_log():
-            logger.info(  # pylint: disable=W1203
-                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
-            )
+        if "moe" in param_group.keys() and param_group["moe"]:
+            # just add current params to params_per_rank[_zero_local_rank]
+            params_per_rank[self._zero_local_rank] = list(param_list)
+            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
+            no_params_ranks = list(range(self._zero_world_size))
+            no_params_ranks.pop(self._zero_world_size)
+        else:
+            sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
+            for i, param in enumerate(sorted_params):
+                global_id = str(i)
+                for j in range(len(param.size())):
+                    global_id = "_".join([global_id, str(param.size()[j])])
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
+                params_per_rank[rank_to_go].append(param)
+                self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
+                numel_per_rank[rank_to_go] += param.numel()
+            # check whether any rank is not assigned to parameters.
+            for rank, params in enumerate(params_per_rank):
+                if len(params) == 0:
+                    no_params_ranks.append(rank)
+            if gpc.is_rank_for_log():
+                logger.info(  # pylint: disable=W1203
+                    f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+                )
         return params_per_rank, set(no_params_ranks)
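The new _partition_param_list therefore has two paths: an expert ("moe") group is handed to the local rank as-is (its sharding is handled within the expert data-parallel group), while every other group keeps the original greedy, numel-balanced assignment. A self-contained sketch of that balancing strategy, written here only to make the algorithm explicit (names are illustrative, not taken from the repository):

from typing import List

import torch

def partition_by_numel(params: List[torch.Tensor], world_size: int) -> List[List[torch.Tensor]]:
    """Greedy balancing: largest tensors first, each assigned to the currently lightest rank."""
    params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
    numel_per_rank = [0] * world_size
    for param in sorted(params, key=lambda p: p.numel(), reverse=True):
        rank_to_go = numel_per_rank.index(min(numel_per_rank))  # rank with the fewest elements so far
        params_per_rank[rank_to_go].append(param)
        numel_per_rank[rank_to_go] += param.numel()
    return params_per_rank

_get_real_dp_process_group applies the same "moe" flag at the process-group level: expert groups use ParallelMode.EXPERT_DATA, all other groups use ParallelMode.DATA.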
@@ -296,6 +310,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
+                # we should not reduce the param in moe
                 if param.requires_grad and not is_moe_param(param):
                     reduce_rank = None
@@ -496,6 +511,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if not self._overlap_communication:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
+                    # we should not reduce the param in moe
                     if param.grad is not None and not is_moe_param(param):
                         self._store_and_try_reduce_grads_by_bucket(param)
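Both comments added above rely on the is_moe_param check to keep expert gradients out of the ZeRO reduce buckets, since those gradients are synchronized over the expert data-parallel group instead. The helper is defined elsewhere in the repository; a plausible shape for it, assuming expert parameters carry an is_expert flag (an assumption for this sketch, not the confirmed implementation), is:

import torch

def is_moe_param(param: torch.Tensor) -> bool:
    # Assumed tagging scheme: expert parameters are marked with `is_expert`
    # when the MoE layers are created; anything untagged is a dense parameter.
    return getattr(param, "is_expert", False)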

View File

@@ -43,7 +43,7 @@ def sync_tensor(tensor, parallel_mode):
 # TODO: will be used in expert data parallel, may can also used in sync_model_param_within_tp
-def sync_model_param_within_ep(model):
+def sync_model_param_with_ep(model):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
     Args:
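The renamed sync_model_param_with_ep is intended to broadcast each parameter inside the data-parallel group that actually holds replicas of it: dense weights across ParallelMode.DATA, expert weights only across ParallelMode.EXPERT_DATA. A hedged sketch of that idea, written from the surrounding context rather than copied from the file (import paths and helper names are assumptions):

import torch.distributed as dist

# Assumed imports, following the repository's usual layout; `is_moe_param`
# is the same helper referenced in the optimizer diff above.
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc


def sync_model_param_with_ep(model):
    """Broadcast every parameter from the first rank of its own data-parallel group."""
    for param in model.parameters():
        mode = ParallelMode.EXPERT_DATA if is_moe_param(param) else ParallelMode.DATA
        if gpc.is_initialized(mode) and gpc.get_world_size(mode) > 1:
            ranks = gpc.get_ranks_in_group(mode)
            dist.broadcast(param, src=ranks[0], group=gpc.get_group(mode))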