fix embedding convert bug

pull/544/head
Wenwen Qu 2023-12-15 13:39:33 +08:00
parent 014d74c20e
commit 934f60b753
1 changed file with 7 additions and 7 deletions

@@ -90,13 +90,13 @@ def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
                 moe_states[layer_i][expert_id][i][
                     f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight"
                 ] = w3s[i].clone()
-    if embed_split_hidden:
-        embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
-        states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
-    else:
-        embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
-        states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()
+    for i in range(tp_size):
+        if embed_split_hidden:
+            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
+            states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
+        else:
+            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
+            states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()
     outputs = hf_state["lm_head.weight"].chunk(tp_size, 0)
     for i in range(tp_size):
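
In the old code the embedding chunks were assigned outside any loop over tensor-parallel ranks, so `states[i]` and `embeds[i]` apparently reused whatever value `i` was left holding from the preceding expert loop; only one shard was written, and not necessarily with its own slice. The fix wraps the assignment in its own `for i in range(tp_size)` loop so every rank's state dict receives its chunk of the embedding matrix. Below is a minimal sketch of that per-rank chunking, assuming plain PyTorch and made-up shapes; it is an illustration of the pattern, not code from the repository.

    # Illustrative sketch (not from the repo): shard the embedding matrix
    # across tensor-parallel ranks the way the fixed loop does.
    import torch

    tp_size = 4
    vocab_size, hidden_size = 8, 16          # made-up shapes for illustration
    hf_state = {"model.embed_tokens.weight": torch.randn(vocab_size, hidden_size)}
    states = [{} for _ in range(tp_size)]
    embed_split_hidden = False

    for i in range(tp_size):                 # the loop the fix introduces
        if embed_split_hidden:
            # split along the hidden dimension (dim 1)
            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
            states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
        else:
            # split along the vocabulary dimension (dim 0)
            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
            states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()

    # every tensor-parallel shard now holds its own slice of the embedding
    for i, shard in enumerate(states):
        key = ("model.tok_embeddings.weight" if embed_split_hidden
               else "model.tok_embeddings.word_embeddings.weight")
        print(i, shard[key].shape)           # torch.Size([2, 16]) for the vocab split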