mirror of https://github.com/InternLM/InternLM
fix moe bugs in zero optimizer
parent 3bfaad895a
commit 754f1d961a
@@ -166,21 +166,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
-            if "moe" in param_group.keys() and param_group["moe"]:
-                print("true", flush=True)
-                continue
-
             group_params = param_group["params"]
 
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params
 
             # assign parameters to ranks; the params in the list are sorted
-            params_per_rank, no_params_ranks = self._partition_param_list(group_params)
+            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
 
-            # store the mapping between param and rank; each param should belong to only one rank
+            # store the mapping between param and rank; each param should belong to only one rank.
+            # we could skip the moe params and not keep them in _param_store to save memory
+            # (which means handling moe params in a different way), but that would increase
+            # complexity and reduce code readability.
             for rank, params in enumerate(params_per_rank):
                 # check whether any rank is not assigned params.
                 if len(params) != 0:
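For context on the "moe" flag checked above: the optimizer only sees it if whoever builds the param groups tags expert parameters with it. A minimal sketch of how such groups might be constructed (the is_expert marker and the grouping helper are assumptions for illustration, not part of this diff):

import torch

def split_params_into_moe_groups(model, base_lr=1e-4):
    # Hypothetical helper: separate expert (moe) parameters from dense ones so
    # the optimizer gets two param groups, with the expert group tagged "moe": True.
    dense_params, moe_params = [], []
    for param in model.parameters():
        if getattr(param, "is_expert", False):  # assumed marker set by the MoE layer
            moe_params.append(param)
        else:
            dense_params.append(param)
    return [
        {"params": dense_params, "lr": base_lr},             # sharded across ParallelMode.DATA
        {"params": moe_params, "lr": base_lr, "moe": True},  # kept whole on each expert-data rank
    ]

# Usage sketch: optimizer = torch.optim.AdamW(split_params_into_moe_groups(model))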
@@ -261,12 +260,27 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
 
-    def _partition_param_list(self, param_list):
+    def _get_real_dp_process_group(self, param_groups):
+        if "moe" in param_groups.keys() and param_groups["moe"]:
+            return ParallelMode.EXPERT_DATA
+        else:
+            return ParallelMode.DATA
+
+    def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
         numel_per_rank = [0 for _ in range(self._zero_world_size)]
         self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
+        param_list = param_group["params"]
 
-        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-        for i, param in enumerate(sorted_params):
-            global_id = str(i)
+        if "moe" in param_group.keys() and param_group["moe"]:
+            # just add current params to params_per_rank[_zero_local_rank]
+            params_per_rank[self._zero_local_rank] = list(param_list)
+            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
+            no_params_ranks = list(range(self._zero_world_size))
+            no_params_ranks.pop(self._zero_local_rank)
+
+        else:
+            sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
+            for i, param in enumerate(sorted_params):
+                global_id = str(i)
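To make the new partition behaviour concrete, here is a standalone sketch that mirrors the logic above without the optimizer's bookkeeping. The greedy assignment in the dense branch is an assumption, since the hunk cuts off right after the sort; the world size and toy tensors are made up for illustration:

import torch

def partition_param_list(param_group, zero_local_rank, zero_world_size):
    params_per_rank = [[] for _ in range(zero_world_size)]
    numel_per_rank = [0 for _ in range(zero_world_size)]
    param_list = param_group["params"]

    if param_group.get("moe", False):
        # moe params are not sharded: the local rank keeps all of them and
        # every other rank is reported as having no params for this group.
        params_per_rank[zero_local_rank] = list(param_list)
        no_params_ranks = [r for r in range(zero_world_size) if r != zero_local_rank]
    else:
        # assumed greedy balancing: largest params first, each to the least-loaded rank
        for param in sorted(param_list, key=lambda x: x.numel(), reverse=True):
            rank = numel_per_rank.index(min(numel_per_rank))
            params_per_rank[rank].append(param)
            numel_per_rank[rank] += param.numel()
        no_params_ranks = [r for r in range(zero_world_size) if not params_per_rank[r]]
    return params_per_rank, no_params_ranks

dense = {"params": [torch.zeros(8), torch.zeros(4), torch.zeros(2)]}
moe = {"params": [torch.zeros(6)], "moe": True}
print(partition_param_list(dense, zero_local_rank=0, zero_world_size=2)[1])  # [] -> every rank got params
print(partition_param_list(moe, zero_local_rank=0, zero_world_size=2)[1])    # [1] -> only rank 0 holds moe params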
@@ -296,6 +310,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
+                # we should not reduce the param in moe
                 if param.requires_grad and not is_moe_param(param):
                     reduce_rank = None
 
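This hunk and the next one both guard the reduction with is_moe_param, which the diff does not define. A plausible sketch of such a helper (the is_expert attribute is an assumption; DeepSpeed-style MoE layers use a similar allreduce = False tag on expert parameters):

import torch

def is_moe_param(param: torch.Tensor) -> bool:
    # Assumed convention: the MoE layer marks its expert parameters when it
    # creates them, so the ZeRO optimizer can skip them during the ordinary
    # data-parallel gradient reduction.
    return getattr(param, "is_expert", False)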
@@ -496,6 +511,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         if not self._overlap_communication:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
+                    # we should not reduce the param in moe
                     if param.grad is not None and not is_moe_param(param):
                         self._store_and_try_reduce_grads_by_bucket(param)
 
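With expert params excluded from both reduction paths above, their gradients still have to be synchronised somewhere else. This commit does not show that code; a rough sketch of the usual approach, assuming plain torch.distributed and an already-created expert-data process group, would be an all-reduce restricted to that group:

import torch.distributed as dist

def reduce_moe_grads(moe_params, expert_dp_group):
    # Average expert-param gradients only across the ranks that replicate the
    # same experts (the expert-data group), not the full data-parallel group.
    world_size = dist.get_world_size(group=expert_dp_group)
    for param in moe_params:
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=expert_dp_group)
            param.grad.div_(world_size)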
@@ -43,7 +43,7 @@ def sync_tensor(tensor, parallel_mode):
 
 
 # TODO: will be used in expert data parallel, may also be used in sync_model_param_within_tp
-def sync_model_param_within_ep(model):
+def sync_model_param_with_ep(model):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
 
     Args:
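The body of the renamed sync_model_param_with_ep is cut off above. As a rough illustration only (process groups and source ranks are taken as arguments here, whereas the real helper presumably resolves them from the framework's global context), such a sync could broadcast each parameter within the group it is replicated in:

import torch.distributed as dist

def sync_model_param_with_ep(model, dp_group, dp_src, expert_group, expert_src):
    # Broadcast every parameter from a designated source rank of its group so
    # that all replicas start from identical weights; expert params use the
    # expert-data group, dense params the ordinary data-parallel group.
    for param in model.parameters():
        if getattr(param, "is_expert", False):  # assumed expert marker
            dist.broadcast(param.data, src=expert_src, group=expert_group)
        else:
            dist.broadcast(param.data, src=dp_src, group=dp_group)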