mirror of https://github.com/InternLM/InternLM

fix bugs in save/load moe checkpoint

parent f6cadcafa2
commit 85f4d4af58
@@ -257,7 +257,7 @@ def save_model_checkpoint(folder, model):
         llm_save(topo_fp, saved_obj=topo)

     # try to save expert parameters to separate files if the model has MoE layers
-    try_save_moe_checkpoint(folder, model)
+    try_save_moe_checkpoint(folder, model, tp_rank, pp_rank)

     torch.distributed.barrier()
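The new tp_rank and pp_rank arguments are the caller's tensor-parallel and pipeline-parallel ranks. This hunk does not show where they come from; below is a minimal sketch of how they could be derived at the call site, assuming the gpc context object used elsewhere in this file exposes per-mode local ranks:

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc

# Sketch only: derive the ranks passed into try_save_moe_checkpoint.
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)    # tensor-parallel rank
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)  # pipeline-parallel rank
try_save_moe_checkpoint(folder, model, tp_rank, pp_rank)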
@@ -306,7 +306,7 @@ def load_model_checkpoint(folder, model):
     """

     # try to load expert parameters from separate files if the model has MoE layers
-    try_load_moe_checkpoint(folder, model, states)
+    try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)

     missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
@@ -319,9 +319,10 @@ def load_model_checkpoint(folder, model):
     torch.cuda.empty_cache()


-def try_save_moe_checkpoint(folder, model):
+def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
     # Using layer_#_expert_# to save the model's expert state_dict, a hack.
-    moe_layer_id = 0
+    pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
+    moe_layer_id = pp_rank * pipeline_stage_size
     for n_module, module in model.named_modules():
         if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
             num_local_experts = module.num_local_experts
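The substantive fix here: under pipeline parallelism each stage owns only a contiguous slice of the layers, so numbering MoE layers from 0 on every stage made different stages write checkpoints under the same layer id. Starting from pp_rank * pipeline_stage_size gives each stage its global layer offset. A small worked example with illustrative numbers (32 layers, 4 pipeline stages; not taken from the commit):

# Illustrative numbers only.
num_layers = 32          # stands in for gpc.config.model.num_layers
pipeline_world_size = 4  # stands in for gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_stage_size = num_layers // pipeline_world_size  # 8 layers per stage

for pp_rank in range(pipeline_world_size):
    moe_layer_id = pp_rank * pipeline_stage_size
    print(f"pp_rank={pp_rank} numbers its MoE layers from {moe_layer_id}")
# pp_rank=0 -> 0, pp_rank=1 -> 8, pp_rank=2 -> 16, pp_rank=3 -> 24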
@@ -354,7 +355,7 @@ def try_save_moe_checkpoint(folder, model):
             # let's save the moe parameters
             for global_expert_id, expert_state_dict in experts_state_dict.items():
                 # save the moe parameters
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 llm_save(fp, saved_obj=expert_state_dict)
             moe_layer_id += 1
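The _tp{tp_rank} suffix is the second half of the fix: each tensor-parallel rank holds a different shard of an expert's weights, and without the suffix every rank wrote to the same filename, clobbering the others. A sketch of the filenames the new scheme produces, with made-up ranks and counts:

# Hypothetical example: MoE layer 16, 4 experts, 2 tensor-parallel ranks.
moe_layer_id = 16
for tp_rank in range(2):
    for global_expert_id in range(4):
        print(f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt")
# model_moe_layer16_expert0_tp0.pt ... model_moe_layer16_expert3_tp1.pt
# The old scheme produced model_moe_layer16_expert0.pt on every tp rank.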
@@ -399,8 +400,9 @@ def save_optimizer_checkpoint(optim, state_path):
     llm_save(os.path.join(state_path, fp), states)


-def try_load_moe_checkpoint(folder, model, state_dict):
-    moe_layer_id = 0
+def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
+    pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
+    moe_layer_id = pp_rank * pipeline_stage_size
     for _, module in model.named_modules():
         if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
             num_local_experts = module.num_local_experts
@@ -408,7 +410,7 @@ def try_load_moe_checkpoint(folder, model, state_dict):
             # loop over all local_experts
             for local_expert_id in range(num_local_experts):
                 global_expert_id = expp_rank * num_local_experts + local_expert_id
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 expert_state_dict = llm_load(fp, map_location=get_current_device())
                 # Updating global -> local expert ids
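The truncated "Updating global -> local expert ids" step renames keys from the on-disk global expert id back to the module's local expert id before the tensors are merged into state_dict. A minimal sketch of that remapping, assuming DeepSpeed-style keys that embed the expert index as "experts.<id>." (the real key format may differ):

# Hypothetical sketch of the global -> local expert id rewrite.
def remap_expert_keys(expert_state_dict, global_expert_id, local_expert_id):
    return {
        key.replace(f"experts.{global_expert_id}.", f"experts.{local_expert_id}."): value
        for key, value in expert_state_dict.items()
    }

# e.g. {"experts.5.w1": t} becomes {"experts.1.w1": t} when global expert 5
# is this rank's local expert 1.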