mirror of https://github.com/InternLM/InternLM
fix embedding convert bug
parent 014d74c20e
commit 934f60b753
@@ -90,13 +90,13 @@ def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
                    moe_states[layer_i][expert_id][i][
                        f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight"
                    ] = w3s[i].clone()

    for i in range(tp_size):
        if embed_split_hidden:
            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
            states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
        else:
            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
            states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()

    outputs = hf_state["lm_head.weight"].chunk(tp_size, 0)
    for i in range(tp_size):
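Below is a minimal, self-contained sketch of the embedding split this hunk performs, using toy tensor sizes and a stand-in hf_state dict rather than a real checkpoint. With embed_split_hidden set, the (vocab_size, hidden_size) embedding is chunked along dim 1, so each tensor-parallel rank keeps every token row but only a slice of the hidden dimension; otherwise it is chunked along dim 0 (a vocabulary slice), which is also why the two branches write different target keys. The "model.output.weight" key for the lm_head shards is a placeholder of mine, since the hunk is truncated before the actual assignment.

import torch

tp_size = 4
vocab_size, hidden_size = 32, 16  # toy sizes, not real model dims

# Stand-in for the loaded HuggingFace state dict.
hf_state = {
    "model.embed_tokens.weight": torch.randn(vocab_size, hidden_size),
    "lm_head.weight": torch.randn(vocab_size, hidden_size),
}
states = [{} for _ in range(tp_size)]  # one shard dict per tensor-parallel rank

embed_split_hidden = True
for i in range(tp_size):
    if embed_split_hidden:
        # Split along dim 1: every rank keeps all token rows,
        # but only 1/tp_size of the hidden dimension.
        embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
        states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
    else:
        # Split along dim 0: every rank keeps 1/tp_size of the vocabulary.
        embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
        states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()

# The output projection is split along the vocab dimension (dim 0).
outputs = hf_state["lm_head.weight"].chunk(tp_size, 0)
for i in range(tp_size):
    # "model.output.weight" is a placeholder key; the hunk cuts off
    # before the real assignment.
    states[i]["model.output.weight"] = outputs[i].clone()

assert states[0]["model.tok_embeddings.weight"].shape == (vocab_size, hidden_size // tp_size)
assert states[0]["model.output.weight"].shape == (vocab_size // tp_size, hidden_size)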