mirror of https://github.com/InternLM/InternLM
remove suffix for gate key
parent
196514d87f
commit
07c98c4a39
|
@ -612,7 +612,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
|
||||||
# get all moe parameters
|
# get all moe parameters
|
||||||
moe_state_dict = {}
|
moe_state_dict = {}
|
||||||
for n, p in module.state_dict().items():
|
for n, p in module.state_dict().items():
|
||||||
if "expert" in n and "moe_layer.gate.wg.weight" not in n:
|
if "expert" in n and "moe_layer.gate" not in n:
|
||||||
moe_state_dict[n_module + "." + n] = p
|
moe_state_dict[n_module + "." + n] = p
|
||||||
moe_str_prefix = ".moe_layer.experts.experts."
|
moe_str_prefix = ".moe_layer.experts.experts."
|
||||||
# Reorder the moe name rank, so that each checkpoint only has one expert
|
# Reorder the moe name rank, so that each checkpoint only has one expert
|
||||||
|
@ -647,7 +647,7 @@ def get_non_moe_state_dict(full_state_dict):
|
||||||
Get the state dict of the non-moe layers
|
Get the state dict of the non-moe layers
|
||||||
"""
|
"""
|
||||||
for key in list(full_state_dict.keys()):
|
for key in list(full_state_dict.keys()):
|
||||||
if "expert" in key and "moe_layer.gate.wg.weight" not in key:
|
if "expert" in key and "moe_layer.gate" not in key:
|
||||||
full_state_dict.pop(key)
|
full_state_dict.pop(key)
|
||||||
|
|
||||||
return full_state_dict
|
return full_state_dict
|
||||||
|
|
Loading…
Reference in New Issue