diff --git a/tools/transformers/mixtral2llamamoe.py b/tools/transformers/mixtral2llamamoe.py
index a7ff098..1a44b41 100644
--- a/tools/transformers/mixtral2llamamoe.py
+++ b/tools/transformers/mixtral2llamamoe.py
@@ -6,7 +6,7 @@ from tqdm import tqdm
 from transformers import AutoConfig
 
 
-def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
+def revert(src, tgt, tp_size, pp_size, embed_split_hidden, adapt_hf, use_flash):
     hf_state = {}
     print("Loading HF checkpoints...")
     for filename in tqdm(os.listdir(src)):
@@ -33,75 +33,81 @@ def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
         return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
 
     # revert
-    states = [{} for _ in range(tp_size)]
+    # no-moe is stored according to tp and pp ranks
+    states = [[{} for _ in range(tp_size)] for _ in range(pp_size)]
+    # moe is stored according to layer id, expert id and tp rank
     moe_states = [
         [[{} for _ in range(tp_size)] for _ in range(config.num_experts)] for _ in range(config.num_hidden_layers)
     ]
 
     # layers
-    for layer_i in tqdm(range(config.num_hidden_layers)):
-        # no-moe
-        for i in range(tp_size):
-            states[i][f"model.layers.{layer_i}.attention_norm.weight"] = hf_state[
-                f"model.layers.{layer_i}.input_layernorm.weight"
-            ].clone()
-            states[i][f"model.layers.{layer_i}.ffn_norm.weight"] = hf_state[
-                f"model.layers.{layer_i}.post_attention_layernorm.weight"
-            ].clone()
-            states[i][f"model.layers.{layer_i}.feed_forward.moe_layer.gate.wg.weight"] = hf_state[
-                f"model.layers.{layer_i}.mlp.gate.weight"
-            ].clone()
-
-        # mha
-        wqs = (
-            permute(hf_state[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
-            # .view(-1, dims_per_head, dim)
-            .chunk(tp_size, 0)
-        )
-        wks = (
-            permute(hf_state[f"model.layers.{layer_i}.self_attn.k_proj.weight"], n_kv_heads, -1, dim)
-            # .view(-1, dims_per_head, dim)
-            .chunk(tp_size, 0)
-        )
-        wvs = (
-            hf_state[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
-            # .view(-1, dims_per_head, dim)
-            .chunk(tp_size, 0)
-        )
-        wos = hf_state[f"model.layers.{layer_i}.self_attn.o_proj.weight"].chunk(tp_size, 1)
-        for i in range(tp_size):
-            states[i][f"model.layers.{layer_i}.attention.wq.weight"] = wqs[i].reshape(-1, dim).clone()
-            states[i][f"model.layers.{layer_i}.attention.wk.weight"] = wks[i].reshape(-1, dim).clone()
-            states[i][f"model.layers.{layer_i}.attention.wv.weight"] = wvs[i].reshape(-1, dim).clone()
-            states[i][f"model.layers.{layer_i}.attention.wo.weight"] = wos[i].clone()
-
-        # moe
-        for expert_id in range(config.num_experts):
-            w1s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w1.weight"].chunk(tp_size, 0)
-            w2s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w3.weight"].chunk(tp_size, 0)
-            w3s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w2.weight"].chunk(tp_size, 1)
+    assert config.num_hidden_layers % pp_size == 0
+    num_layer_per_stage = config.num_hidden_layers // pp_size
+    for p_i in range(pp_size):
+        for layer_i in tqdm(range(num_layer_per_stage)):
+            global_layer_i = p_i * num_layer_per_stage + layer_i
+            # no-moe
             for i in range(tp_size):
-                moe_states[layer_i][expert_id][i][
-                    f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w1.weight"
-                ] = w1s[i].clone()
-                moe_states[layer_i][expert_id][i][
-                    f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w2.weight"
-                ] = w2s[i].clone()
-                moe_states[layer_i][expert_id][i][
-                    f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight"
-                ] = w3s[i].clone()
+                states[p_i][i][f"model.layers.{layer_i}.attention_norm.weight"] = hf_state[
+                    f"model.layers.{global_layer_i}.input_layernorm.weight"
+                ].clone()
+                states[p_i][i][f"model.layers.{layer_i}.ffn_norm.weight"] = hf_state[
+                    f"model.layers.{global_layer_i}.post_attention_layernorm.weight"
+                ].clone()
+                states[p_i][i][f"model.layers.{layer_i}.feed_forward.moe_layer.gate.wg.weight"] = hf_state[
+                    f"model.layers.{global_layer_i}.mlp.gate.weight"
+                ].clone()
+
+            # mha
+            wqs = (
+                permute(hf_state[f"model.layers.{global_layer_i}.self_attn.q_proj.weight"])
+                # .view(-1, dims_per_head, dim)
+                .chunk(tp_size, 0)
+            )
+            wks = (
+                permute(hf_state[f"model.layers.{global_layer_i}.self_attn.k_proj.weight"], n_kv_heads, -1, dim)
+                # .view(-1, dims_per_head, dim)
+                .chunk(tp_size, 0)
+            )
+            wvs = (
+                hf_state[f"model.layers.{global_layer_i}.self_attn.v_proj.weight"]
+                # .view(-1, dims_per_head, dim)
+                .chunk(tp_size, 0)
+            )
+            wos = hf_state[f"model.layers.{global_layer_i}.self_attn.o_proj.weight"].chunk(tp_size, 1)
+            for i in range(tp_size):
+                states[p_i][i][f"model.layers.{layer_i}.attention.wq.weight"] = wqs[i].reshape(-1, dim).clone()
+                states[p_i][i][f"model.layers.{layer_i}.attention.wk.weight"] = wks[i].reshape(-1, dim).clone()
+                states[p_i][i][f"model.layers.{layer_i}.attention.wv.weight"] = wvs[i].reshape(-1, dim).clone()
+                states[p_i][i][f"model.layers.{layer_i}.attention.wo.weight"] = wos[i].clone()
+
+            # moe
+            for expert_id in range(config.num_experts):
+                w1s = hf_state[f"model.layers.{global_layer_i}.mlp.experts.{expert_id}.w1.weight"].chunk(tp_size, 0)
+                w2s = hf_state[f"model.layers.{global_layer_i}.mlp.experts.{expert_id}.w3.weight"].chunk(tp_size, 0)
+                w3s = hf_state[f"model.layers.{global_layer_i}.mlp.experts.{expert_id}.w2.weight"].chunk(tp_size, 1)
+                for i in range(tp_size):
+                    moe_states[global_layer_i][expert_id][i][
+                        f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w1.weight"
+                    ] = w1s[i].clone()
+                    moe_states[global_layer_i][expert_id][i][
+                        f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w2.weight"
+                    ] = w2s[i].clone()
+                    moe_states[global_layer_i][expert_id][i][
+                        f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight"
+                    ] = w3s[i].clone()
 
     for i in range(tp_size):
         if embed_split_hidden:
             embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
-            states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
+            states[0][i]["model.tok_embeddings.weight"] = embeds[i].clone()
         else:
             embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
-            states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()
+            states[0][i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()
     outputs = hf_state["lm_head.weight"].chunk(tp_size, 0)
     for i in range(tp_size):
-        states[i]["model.norm.weight"] = hf_state["model.norm.weight"].clone()
-        states[i]["model.output.weight"] = outputs[i].clone()
+        states[pp_size - 1][i]["model.norm.weight"] = hf_state["model.norm.weight"].clone()
+        states[pp_size - 1][i]["model.output.weight"] = outputs[i].clone()
 
     mlp_ratio = round((config.intermediate_size - 255) / config.hidden_size + 0.01, 2)
     if "rotary" in config.to_dict():
@@ -134,8 +140,9 @@ def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
     # save
     os.makedirs(tgt, exist_ok=True)
     print(f"Saving to {tgt}...")
-    for tp in tqdm(range(tp_size)):
-        torch.save(states[tp], os.path.join(tgt, f"model_tp{tp}_pp0.pt"))
+    for pp in range(pp_size):
+        for tp in tqdm(range(tp_size)):
+            torch.save(states[pp][tp], os.path.join(tgt, f"model_tp{tp}_pp{pp}.pt"))
     for moe_layer_id in range(config.num_hidden_layers):
         for expert_id in range(config.num_experts):
             for tp in tqdm(range(tp_size)):
@@ -151,6 +158,7 @@ def print_args(args):
     print(f"Source Path: {args.src}")
     print(f"Target Path: {args.tgt}")
     print(f"TP Size: {args.tp_size}")
+    print(f"PP Size: {args.pp_size}")
     print(f"Embeb Split Hidden: {args.embed_split}")
     print(f"Adapt HF: {args.adapt_hf}")
     print(f"Use Flash Attn: {args.use_flash}")
@@ -163,6 +171,7 @@ def parse_args():
     parser.add_argument("--src", type=str, help="Input folder")
     parser.add_argument("--tgt", type=str, help="Output folder")
     parser.add_argument("--tp_size", type=int, help="world_size of tensor parallel")
+    parser.add_argument("--pp_size", type=int, help="world_size of pipeline parallel")
     parser.add_argument("--embed_split", action="store_true", help="embed_split_hidden of InternLM")
     parser.add_argument("--adapt_hf", action="store_true", help="adapt_hf of InternLM")
    parser.add_argument("--use_flash", action="store_true", help="use_flash_attn of InternLM")
@@ -178,4 +187,4 @@ if __name__ == "__main__":
     args = parse_args()
     print_args(args)
 
-    revert(args.src, args.tgt, args.tp_size, args.embed_split, args.adapt_hf, args.use_flash)
+    revert(args.src, args.tgt, args.tp_size, args.pp_size, args.embed_split, args.adapt_hf, args.use_flash)
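
For reference, a possible invocation of the updated script; the checkpoint paths and the parallel sizes below are illustrative placeholders rather than values taken from this patch, and the store_true flags should be chosen to match the target InternLM configuration:

    python tools/transformers/mixtral2llamamoe.py \
        --src /path/to/hf_mixtral_moe_checkpoint \
        --tgt /path/to/internlm_moe_checkpoint \
        --tp_size 2 \
        --pp_size 2 \
        --embed_split \
        --use_flash

With tp_size=2 and pp_size=2 the non-MoE weights are expected to be written as model_tp{0,1}_pp{0,1}.pt, with the token embeddings stored only in the pp0 shards and the final norm and output head only in the last-stage shards; the per-layer, per-expert MoE weights are still saved by the existing MoE save loop.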