mirror of https://github.com/InternLM/InternLM
fix the above code:
* treat optim.zero_world_size and optim.zero_local_rank as lists in model_checkpoint.py and test_model_checkpoint.py
* add overlap and zero checks for MoE in args_sanity_check()
parent c018e9216f
commit e8fcbb1ad5
@@ -319,6 +319,13 @@ def args_sanity_check():
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
+    # moe not support overlap and zero1.5 for now
+    if hasattr(gpc.config.model, "num_experts"):
+        assert (
+            not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
+        ), "not support overlap and moe at the same time"
+        assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
+
 
 def launch(
     config: Union[str, Path, Config, Dict],
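Note on the new assertions: in Python, `&` binds more tightly than `not`, so the first assertion parses as `not (overlap_sync_grad & overlap_sync_param)` and only rejects the case where both overlap options are enabled together with MoE. Below is a minimal standalone sketch of the same checks, using a hypothetical helper and plain arguments in place of `gpc.config` and `optim_ckpt`; it is not part of the InternLM code base.

# Hypothetical restatement of the MoE sanity check added above.
def check_moe_compat(is_moe: bool, overlap_sync_grad: bool,
                     overlap_sync_param: bool, zero1: int) -> None:
    if not is_moe:
        return  # non-MoE models are unaffected
    # Only the combination of both overlap flags being enabled is rejected.
    assert not (overlap_sync_grad and overlap_sync_param), \
        "not support overlap and moe at the same time"
    # zero1 = -1 lets the ZeRO-1 group cover the whole data-parallel group.
    assert zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"

check_moe_compat(is_moe=True, overlap_sync_grad=False,
                 overlap_sync_param=False, zero1=-1)  # passes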
@@ -288,7 +288,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])
             if self._overlap_sync_param:
-                assert not hasattr(gpc.config.model, "num_experts")
                 rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
             else:
                 rank_to_go = numel_per_rank.index(min(numel_per_rank))
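The per-parameter assertion on `gpc.config.model.num_experts` is dropped here because the same constraint is now enforced once, up front, in args_sanity_check(). The surrounding context shows the fallback path when parameter-sync overlap is off: each parameter is placed on the ZeRO-1 rank that currently holds the fewest elements. A simplified, self-contained sketch of that greedy partitioning (not the actual HybridZeroOptimizer code):

def partition_params(param_numels, zero_world_size):
    # Greedily place each parameter on the least-loaded ZeRO-1 rank.
    numel_per_rank = [0] * zero_world_size
    buckets = [[] for _ in range(zero_world_size)]
    for idx, numel in enumerate(param_numels):
        rank_to_go = numel_per_rank.index(min(numel_per_rank))
        buckets[rank_to_go].append(idx)
        numel_per_rank[rank_to_go] += numel
    return buckets

print(partition_params([4096, 1024, 1024, 2048], zero_world_size=2))
# -> [[0], [1, 2, 3]]: 4096 elements end up on rank 0 and 4096 on rank 1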
@@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path):
     zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    zero_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
     fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
 
     states = optim.state_dict()
     if isinstance(optim, HybridZeroOptimizer):
-        if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
+        if gpc.get_global_rank() < zero_size * tp_size * pp_size:
             llm_save(os.path.join(state_path, fp), states)
             if "zero_devide_optim_plan" in states:
                 params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
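In model_checkpoint.py the guard used to multiply `optim.zero_world_size`, which this commit treats as a per-parameter-group list: multiplying a list by an int repeats the list, and comparing an int rank against a list with `<` raises a TypeError, so the bound is now taken from `gpc.get_world_size(ParallelMode.ZERO1)` instead. A sketch of what the guard intends, with a hypothetical helper name:

def should_save_optimizer_shard(global_rank, zero_size, tp_size, pp_size):
    # One optimizer shard is written per (zero_rank, tp_rank, pp_rank)
    # combination; global ranks beyond that product hold no unique state.
    return global_rank < zero_size * tp_size * pp_size

# e.g. zero_size=8, tp_size=2, pp_size=1 -> only global ranks 0..15 save
print(should_save_optimizer_shard(3, zero_size=8, tp_size=2, pp_size=1))   # True
print(should_save_optimizer_shard(20, zero_size=8, tp_size=2, pp_size=1))  # False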
@@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2):
         fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
         for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
             re &= group_id_1 == group_id_2
-            if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
+            if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]:
                 re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
     else:
         for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
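In test_model_checkpoint.py the membership test is updated for the same reason: `optim1.zero_local_rank` is now a list, so it has to be indexed by parameter group before checking whether this rank received any parameters for that group. A hypothetical illustration of the per-group bookkeeping the test now assumes:

# zero_local_rank[i]: this rank's position in the ZeRO-1 group of param group i
# param_group_no_params_ranks[i]: ranks that got no parameters for group i
zero_local_rank = [1, 1]
param_group_no_params_ranks = [[3], [1, 3]]  # group 1 gave nothing to rank 1

for group_id in range(len(zero_local_rank)):
    if zero_local_rank[group_id] not in param_group_no_params_ranks[group_id]:
        print(f"group {group_id}: compare fp32 flat buffers")
    else:
        print(f"group {group_id}: rank holds no params for this group, skip")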