From 85f4d4af58fabd86cd5c792532a2a0c5024a6bc1 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 22 Sep 2023 17:54:33 +0800
Subject: [PATCH] fix bugs in save/load moe checkpoint

---
 internlm/utils/model_checkpoint.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index ee8481c..0377d58 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -257,7 +257,7 @@ def save_model_checkpoint(folder, model):
         llm_save(topo_fp, saved_obj=topo)

     # try to save expert parameter to separate files if model have moe layer
-    try_save_moe_checkpoint(folder, model)
+    try_save_moe_checkpoint(folder, model, tp_rank, pp_rank)

     torch.distributed.barrier()

@@ -306,7 +306,7 @@ def load_model_checkpoint(folder, model):
     """

     # try to load expert parameter to separate files if model have moe layer
-    try_load_moe_checkpoint(folder, model, states)
+    try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)

     missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
@@ -319,9 +319,10 @@ def load_model_checkpoint(folder, model):
     torch.cuda.empty_cache()


-def try_save_moe_checkpoint(folder, model):
+def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
     # Using layer_#_expert_# to save the model's expert state_dict,a hack.
-    moe_layer_id = 0
+    pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
+    moe_layer_id = pp_rank * pipeline_stage_size
     for n_module, module in model.named_modules():
         if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
             num_local_experts = module.num_local_experts
@@ -354,7 +355,7 @@ def try_save_moe_checkpoint(folder, model):
             # let save the moe parameters
             for global_expert_id, expert_state_dict in experts_state_dict.items():
                 # save the moe parameters
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 llm_save(fp, saved_obj=expert_state_dict)
             moe_layer_id += 1
@@ -399,8 +400,9 @@ def save_optimizer_checkpoint(optim, state_path):
         llm_save(os.path.join(state_path, fp), states)


-def try_load_moe_checkpoint(folder, model, state_dict):
-    moe_layer_id = 0
+def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
+    pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
+    moe_layer_id = pp_rank * pipeline_stage_size
     for _, module in model.named_modules():
         if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
             num_local_experts = module.num_local_experts
@@ -408,7 +410,7 @@ def try_load_moe_checkpoint(folder, model, state_dict):
             # loop all local_experts
             for local_expert_id in range(num_local_experts):
                 global_expert_id = expp_rank * num_local_experts + local_expert_id
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 expert_state_dict = llm_load(fp, map_location=get_current_device())
                 # Updating global -> local expert ids
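
The patch changes how per-expert checkpoint files are keyed: the MoE layer id now starts from the first layer owned by the pipeline stage (pp_rank * num_layers // pipeline world size) instead of 0, and the tensor-parallel rank is appended to the filename, so expert files written by different pipeline and tensor ranks no longer overwrite each other. The sketch below only illustrates that naming arithmetic; moe_expert_checkpoint_path, its parameters, and the example values are hypothetical stand-ins for gpc.config.model.num_layers, gpc.get_world_size(ParallelMode.PIPELINE), and the ranks passed into the patched functions, not part of the patch itself.

import os


def moe_expert_checkpoint_path(
    folder: str,
    num_layers: int,
    pipeline_world_size: int,
    pp_rank: int,
    tp_rank: int,
    local_moe_layer_idx: int,
    global_expert_id: int,
) -> str:
    # Each pipeline stage owns num_layers // pipeline_world_size layers, so the
    # global MoE layer id starts at pp_rank * pipeline_stage_size rather than 0
    # (the old code restarted at 0 on every stage, so stages overwrote each
    # other's expert files).
    pipeline_stage_size = num_layers // pipeline_world_size
    moe_layer_id = pp_rank * pipeline_stage_size + local_moe_layer_idx
    # Embedding the tensor-parallel rank keeps shards of the same expert
    # written by different tp ranks in separate files.
    fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
    return os.path.join(folder, fn)


if __name__ == "__main__":
    # 32 layers split over 4 pipeline stages: stage 1 starts at global layer 8,
    # so its first local MoE layer, expert 0, tp rank 0 maps to
    # ckpt/model_moe_layer8_expert0_tp0.pt
    print(moe_expert_checkpoint_path("ckpt", 32, 4, pp_rank=1, tp_rank=0,
                                     local_moe_layer_idx=0, global_expert_id=0))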