fix bugs in save/load moe checkpoint

pull/182/head
Wenwen Qu 2023-09-22 17:54:33 +08:00
parent f6cadcafa2
commit 85f4d4af58
1 changed file with 10 additions and 8 deletions


@@ -257,7 +257,7 @@ def save_model_checkpoint(folder, model):
     llm_save(topo_fp, saved_obj=topo)
 
     # try to save expert parameters to separate files if the model has moe layers
-    try_save_moe_checkpoint(folder, model)
+    try_save_moe_checkpoint(folder, model, tp_rank, pp_rank)
 
     torch.distributed.barrier()
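
For context, the two new arguments have to come from the callers' parallel context; the call-site changes are not shown in this hunk. A minimal sketch of how they could be obtained, assuming InternLM's global context `gpc` and its `ParallelMode` enum:

    # Sketch (assumption): derive the local tensor- and pipeline-parallel ranks
    # that save_model_checkpoint / load_model_checkpoint would pass down.
    from internlm.core.context import ParallelMode
    from internlm.core.context import global_context as gpc

    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)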
@@ -306,7 +306,7 @@ def load_model_checkpoint(folder, model):
     """
     # try to load expert parameters from separate files if the model has moe layers
-    try_load_moe_checkpoint(folder, model, states)
+    try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)
     missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
@@ -319,9 +319,10 @@ def load_model_checkpoint(folder, model):
     torch.cuda.empty_cache()
 
-def try_save_moe_checkpoint(folder, model):
+def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
     # Using layer_#_expert_# to save the model's expert state_dict, a hack.
-    moe_layer_id = 0
+    pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
+    moe_layer_id = pp_rank * pipeline_stage_size
     for n_module, module in model.named_modules():
         if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
             num_local_experts = module.num_local_experts
@@ -354,7 +355,7 @@ def try_save_moe_checkpoint(folder, model):
             # let's save the moe parameters
             for global_expert_id, expert_state_dict in experts_state_dict.items():
                 # save the moe parameters
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 llm_save(fp, saved_obj=expert_state_dict)
             moe_layer_id += 1
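
The effect of the new offset is easiest to see with concrete numbers. A self-contained sketch (example values standing in for gpc.config.model.num_layers and the pipeline world size, which this diff reads at runtime):

    # Illustration only: assumed sizes, not values from the real config.
    num_layers = 32
    pipeline_world_size = 4
    pipeline_stage_size = num_layers // pipeline_world_size  # 8 layers per stage

    for pp_rank in range(pipeline_world_size):
        moe_layer_id = pp_rank * pipeline_stage_size  # first layer id owned by this stage
        print(pp_rank, f"model_moe_layer{moe_layer_id}_expert0_tp0.pt")
    # pp_rank 0 -> model_moe_layer0_expert0_tp0.pt
    # pp_rank 1 -> model_moe_layer8_expert0_tp0.pt
    # pp_rank 2 -> model_moe_layer16_expert0_tp0.pt
    # pp_rank 3 -> model_moe_layer24_expert0_tp0.pt

Before the fix, every pipeline stage started counting at 0, so stages holding different layers wrote to the same filenames and overwrote each other; the _tp{tp_rank} suffix likewise keeps tensor-parallel shards of the same expert in separate files.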
@@ -399,8 +400,9 @@ def save_optimizer_checkpoint(optim, state_path):
         llm_save(os.path.join(state_path, fp), states)
 
-def try_load_moe_checkpoint(folder, model, state_dict):
-    moe_layer_id = 0
+def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
+    pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
+    moe_layer_id = pp_rank * pipeline_stage_size
     for _, module in model.named_modules():
         if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
             num_local_experts = module.num_local_experts
@@ -408,7 +410,7 @@ def try_load_moe_checkpoint(folder, model, state_dict):
             # loop over all local experts
             for local_expert_id in range(num_local_experts):
                 global_expert_id = expp_rank * num_local_experts + local_expert_id
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 expert_state_dict = llm_load(fp, map_location=get_current_device())
                 # Updating global -> local expert ids
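
The global_expert_id the loader reconstructs here (unchanged by this commit) maps each expert-parallel rank to a disjoint slice of experts. A small sketch with assumed sizes:

    # Illustration only: assumed expert-parallel group size and experts per rank.
    ep_world_size = 4
    num_local_experts = 2

    for expp_rank in range(ep_world_size):
        owned = [expp_rank * num_local_experts + local_id
                 for local_id in range(num_local_experts)]
        print(expp_rank, owned)
    # 0 -> [0, 1]   1 -> [2, 3]   2 -> [4, 5]   3 -> [6, 7]

Combined with the per-stage layer offset and the tp suffix, each (pp_rank, moe_layer_id, global_expert_id, tp_rank) tuple now maps to a unique checkpoint file, so the load path rebuilds exactly the filenames the save path wrote.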