diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 660cc55..985e57f 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -319,6 +319,13 @@ def args_sanity_check():
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
+    # MoE does not support overlap or zero1.5 for now
+    if hasattr(gpc.config.model, "num_experts"):
+        assert not (
+            optim_ckpt.overlap_sync_grad and optim_ckpt.overlap_sync_param
+        ), "overlap is not supported together with MoE for now"
+        assert gpc.config.parallel.zero1 == -1, "MoE only supports zero1; set zero1=-1 to fix this"
+
 
 def launch(
     config: Union[str, Path, Config, Dict],
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 3ca18e4..3b06e9c 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -288,7 +288,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])
             if self._overlap_sync_param:
-                assert not hasattr(gpc.config.model, "num_experts")
                 rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
             else:
                 rank_to_go = numel_per_rank.index(min(numel_per_rank))
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 566bb0f..00e7436 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path):
     zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    zero_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
     fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
     states = optim.state_dict()
 
     if isinstance(optim, HybridZeroOptimizer):
-        if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
+        if gpc.get_global_rank() < zero_size * tp_size * pp_size:
             llm_save(os.path.join(state_path, fp), states)
             if "zero_devide_optim_plan" in states:
                 params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py
index b12ccad..0804455 100644
--- a/tests/test_utils/test_model_checkpoint.py
+++ b/tests/test_utils/test_model_checkpoint.py
@@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2):
         fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
         for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
             re &= group_id_1 == group_id_2
-            if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
+            if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]:
                 re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
     else:
         for group1, group2 in zip(optim1.param_groups, optim2.param_groups):