From 934f60b753188e2e3a3d6801dfc1e2ff48ff047b Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 15 Dec 2023 13:39:33 +0800
Subject: [PATCH] fix embedding convert bug

---
 tools/transformers/mixtral2llamamoe.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tools/transformers/mixtral2llamamoe.py b/tools/transformers/mixtral2llamamoe.py
index 5eb17a8..a7ff098 100644
--- a/tools/transformers/mixtral2llamamoe.py
+++ b/tools/transformers/mixtral2llamamoe.py
@@ -90,13 +90,13 @@ def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
                 moe_states[layer_i][expert_id][i][
                     f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight"
                 ] = w3s[i].clone()
-
-    if embed_split_hidden:
-        embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
-        states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
-    else:
-        embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
-        states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()
+    for i in range(tp_size):
+        if embed_split_hidden:
+            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
+            states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
+        else:
+            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
+            states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()
 
     outputs = hf_state["lm_head.weight"].chunk(tp_size, 0)
     for i in range(tp_size):
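
Note (not part of the patch): the hunk above wraps the embedding conversion in its own `for i in range(tp_size):` loop, so every tensor-parallel rank gets its embedding shard instead of only whichever rank index `i` happened to hold after the preceding expert-conversion loop. Below is a minimal, self-contained sketch of the corrected per-rank split; the tensor shapes, `tp_size` value, and use of random weights are illustrative assumptions, while the dict keys and the hidden-vs-vocab split come from the patch itself.

    # Hypothetical sketch: chunk the HF embedding matrix once per tensor-parallel rank.
    import torch

    tp_size = 2
    embed_split_hidden = True
    hf_state = {"model.embed_tokens.weight": torch.randn(32000, 4096)}  # illustrative shape
    states = [{} for _ in range(tp_size)]

    for i in range(tp_size):
        if embed_split_hidden:
            # split along the hidden dimension (dim 1)
            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
            states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
        else:
            # split along the vocabulary dimension (dim 0)
            embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
            states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()

    # every rank now holds its own shard, which the buggy un-looped version did not guarantee
    assert all(len(s) == 1 for s in states)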