From 3d609d8e38d6d8100162d503be61a1d1ef5473d8 Mon Sep 17 00:00:00 2001
From: x54-729
Date: Fri, 19 Jan 2024 17:29:40 +0800
Subject: [PATCH] remove --tp_size; rearrange wqkv

---
 tools/convert2llama.py | 155 +++++++++++++++++++++--------------------
 1 file changed, 80 insertions(+), 75 deletions(-)

diff --git a/tools/convert2llama.py b/tools/convert2llama.py
index 252764f..1b6c25d 100644
--- a/tools/convert2llama.py
+++ b/tools/convert2llama.py
@@ -1,110 +1,116 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
 import os
-from collections import defaultdict
+import json
 
 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
+
+
+def save_config(config, tgt):
+    config_dict = config.to_dict()
+    unnecessary_keys = [
+        "_name_or_path",
+        "auto_map",
+        "transformers_version",
+        "model_type",
+        "architectures",
+        "tokenizer_class",
+        "attn_implementation",
+    ]
+    for k in unnecessary_keys:
+        config_dict.pop(k, None)
+    config_dict["attention_bias"] = config_dict.pop("bias")
+    config_dict["architectures"] = ["LlamaForCausalLM"]
+    llama_config = LlamaConfig(**config_dict)
+    llama_config.save_pretrained(tgt)
 
 
-def split_wqkv(qkv, num_groups, q_per_kv, head_dim):
-    """Split wqkv into wq, wk, wv."""
-    qkv = qkv.T
-    qkv = rearrange(qkv, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-
-    q = qkv[..., :q_per_kv, :]
-    k = qkv[..., q_per_kv : q_per_kv + 1, :]
-    v = qkv[..., q_per_kv + 1 : q_per_kv + 2, :]
-
-    q = rearrange(q, "o g n i -> o (g n i)", g=num_groups, n=q_per_kv, i=head_dim)
-    k = rearrange(k, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    v = rearrange(v, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    return q.T, k.T, v.T
-
-
-def convert(src, tgt, tp_size):
+def convert(src, tgt):
     """Convert InternLM2 huggingface checkpoints to Llama-style."""
-    print("Loading origin checkpoints...")
-    hf_states = []
-    hf_state_names = []
-    remain_files = []
-    for filename in tqdm(os.listdir(src)):
-        if not filename.endswith(".bin"):
-            remain_files.append(filename)
-            continue
-        hf_state_names.append(filename)
-        hf_states.append(torch.load(os.path.join(src, filename)))
-
     print("Convert InternLM2 huggingface checkpoints to Llama...")
     config = AutoConfig.from_pretrained(src, trust_remote_code=True)
+    assert not config.bias, "Cannot convert InternLM Model with bias to LLaMA."
-    q_per_kv = config.num_attention_heads // config.num_key_value_heads
     head_dim = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads
-    num_heads_per_tp = num_heads // tp_size
-    num_groups = num_heads_per_tp // q_per_kv
+    num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+
 
-    for states in tqdm(hf_states):
-        tmp_states = defaultdict(defaultdict)
+    # load index json file
+    index_file = os.path.join(src, "pytorch_model.bin.index.json")
+    if os.path.exists(index_file):
+        with open(index_file) as fp:
+            index_dict = json.load(fp)
+        index_dict["weight_map"] = {}
+    else:
+        index_dict = None
+
+    os.makedirs(tgt, exist_ok=True)
+    for filename in tqdm(os.listdir(src)):
+        if not filename.endswith(".bin"):
+            continue
+        states = torch.load(os.path.join(src, filename))
+        llama_states = {}
         for k, v in states.copy().items():
             if "wqkv" in k:
-                wqkvs = v.chunk(tp_size, 0)
-                for i in range(tp_size):
-                    wq, wk, wv = split_wqkv(wqkvs[i], num_groups, q_per_kv, head_dim)
-
-                    _prefix = k.split("attention")[0]
-                    wq_key = _prefix + "self_attn.q_proj.weight"
-                    wk_key = _prefix + "self_attn.k_proj.weight"
-                    wv_key = _prefix + "self_attn.v_proj.weight"
-
-                    tmp_states[wq_key][i] = wq.clone()
-                    tmp_states[wk_key][i] = wk.clone()
-                    tmp_states[wv_key][i] = wv.clone()
+                v = rearrange(
+                    v,
+                    "(h gs d) dim -> h gs d dim",
+                    gs=2 + num_key_value_groups,
+                    d=head_dim,
+                )
+                wq, wk, wv = torch.split(
+                    v, [num_key_value_groups, 1, 1], dim=1
+                )
+                wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
+                wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
+                wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
+                _prefix = k.split("attention")[0]
+                wq_key = _prefix + "self_attn.q_proj.weight"
+                wk_key = _prefix + "self_attn.k_proj.weight"
+                wv_key = _prefix + "self_attn.v_proj.weight"
+                llama_states[wq_key] = wq.clone()
+                llama_states[wk_key] = wk.clone()
+                llama_states[wv_key] = wv.clone()
             elif "attention.wo" in k:
                 new_k = k.replace("attention.wo", "self_attn.o_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w1" in k:
                 new_k = k.replace("feed_forward.w1", "mlp.gate_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w2" in k:
-                new_k = k.replace("feed_forward.w2", "mlp.up_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w2", "mlp.down_proj")
+                llama_states[new_k] = v
             elif "feed_forward.w3" in k:
-                new_k = k.replace("feed_forward.w3", "mlp.down_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w3", "mlp.up_proj")
+                llama_states[new_k] = v
             elif "attention_norm" in k:
                 new_k = k.replace("attention_norm", "input_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
            elif "ffn_norm" in k:
                 new_k = k.replace("ffn_norm", "post_attention_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "tok_embeddings" in k:
-                states["model.embed_tokens.weight"] = v
-                del states[k]
+                llama_states["model.embed_tokens.weight"] = v
             elif "output" in k:
-                states["lm_head.weight"] = v
-                del states[k]
+                llama_states["lm_head.weight"] = v
+            else:
+                llama_states[k] = v
 
-        for k, v in tmp_states.items():
-            states[k] = torch.cat(list(v.values()), dim=0)
+        if index_dict is not None:
+            for k in llama_states.keys():
+                index_dict["weight_map"][k] = filename
+        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
+        torch.save(llama_states, os.path.join(tgt, filename))
+        del states
 
-    os.makedirs(tgt, exist_ok=True)
-    for i, states in enumerate(hf_states):
-        print(f"Saving to {os.path.join(tgt, hf_state_names[i])}...", flush=True)
-        torch.save(states, os.path.join(tgt, hf_state_names[i]))
-    for filename in remain_files:
-        print(f"Copying {filename}...", flush=True)
-        os.system(f"cp {os.path.join(src, filename)} {tgt}")
+    print("Saving config and tokenizer...")
+    # index.json
+    if index_dict is not None:
+        with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
+            json.dump(index_dict, fp, indent=2)
+    # tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(src)
+    tokenizer.init_kwargs.pop("auto_map", None)
+    tokenizer.save_pretrained(tgt)
+    # config
+    save_config(config, tgt)
 
     print("Done!")
 
 
@@ -112,7 +118,6 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--src", type=str, help="Input folder")
     parser.add_argument("--tgt", type=str, help="Output folder")
-    parser.add_argument("--tp_size", type=int, help="world_size of tensor parallel")
 
     args = parser.parse_args()
 
@@ -122,4 +127,4 @@ def parse_args():
 
 if __name__ == "__main__":
     args = parse_args()
-    convert(args.src, args.tgt, args.tp_size)
+    convert(args.src, args.tgt)
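
Note on the new wqkv handling: InternLM2 packs the q/k/v projections into one wqkv weight whose rows are grouped per key/value head, each group holding num_key_value_groups query heads followed by one key head and one value head. The patch therefore recovers q_proj/k_proj/v_proj with a single rearrange plus torch.split, and no longer needs any tensor-parallel bookkeeping. Below is a minimal standalone sketch of that split; the sizes are made-up toy values, while the einops pattern and split widths are copied from the diff:

    # Toy shape check of the wqkv split (all sizes hypothetical).
    import torch
    from einops import rearrange

    num_attention_heads = 8
    num_key_value_heads = 2
    head_dim = 4
    hidden_size = num_attention_heads * head_dim
    num_key_value_groups = num_attention_heads // num_key_value_heads  # query heads per kv head

    # Packed rows: per kv head, num_key_value_groups q heads, then one k, then one v.
    wqkv = torch.randn((num_attention_heads + 2 * num_key_value_heads) * head_dim, hidden_size)

    v = rearrange(wqkv, "(h gs d) dim -> h gs d dim", gs=2 + num_key_value_groups, d=head_dim)
    wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
    wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
    wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
    wv = rearrange(wv, "h gs d dim -> (h gs d) dim")

    assert wq.shape == (num_attention_heads * head_dim, hidden_size)  # q_proj rows
    assert wk.shape == (num_key_value_heads * head_dim, hidden_size)  # k_proj rows
    assert wv.shape == (num_key_value_heads * head_dim, hidden_size)  # v_proj rows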
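
Note on the feed-forward renames: the old script mapped feed_forward.w2 to mlp.up_proj and feed_forward.w3 to mlp.down_proj, which is backwards. InternLM2's MLP computes w2(silu(w1(x)) * w3(x)), while LLaMA's computes down_proj(silu(gate_proj(x)) * up_proj(x)), so the correct mapping is w1 -> gate_proj, w3 -> up_proj, w2 -> down_proj, which is what the patch now does. A toy check that the renaming preserves the function (dimensions hypothetical):

    import torch
    import torch.nn.functional as F

    hidden, inter = 8, 16
    x = torch.randn(2, hidden)
    w1 = torch.randn(inter, hidden)  # InternLM2 w1 -> LLaMA gate_proj
    w3 = torch.randn(inter, hidden)  # InternLM2 w3 -> LLaMA up_proj
    w2 = torch.randn(hidden, inter)  # InternLM2 w2 -> LLaMA down_proj

    internlm2_mlp = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)
    gate_proj, up_proj, down_proj = w1, w3, w2  # the patch's renaming
    llama_mlp = F.linear(F.silu(F.linear(x, gate_proj)) * F.linear(x, up_proj), down_proj)
    assert torch.allclose(internlm2_mlp, llama_mlp)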
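
With --tp_size removed, the converter takes only a source and a target folder, and the output is a plain Llama-format checkpoint with its config, tokenizer, and (if the source was sharded) a rewritten pytorch_model.bin.index.json. A hedged usage sketch; the paths below are placeholders, not from the patch:

    # python tools/convert2llama.py --src ./internlm2-7b --tgt ./internlm2-7b-llama
    # The converted folder should then load with the stock Llama classes,
    # without trust_remote_code:
    from transformers import LlamaForCausalLM, LlamaTokenizer

    model = LlamaForCausalLM.from_pretrained("./internlm2-7b-llama")
    tokenizer = LlamaTokenizer.from_pretrained("./internlm2-7b-llama")
    print(model.config.architectures)  # ["LlamaForCausalLM"], set by save_config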