From 4281caf30b3aba07e8cd38d29e12f5b701c99391 Mon Sep 17 00:00:00 2001
From: Yang Gao
Date: Fri, 19 Jan 2024 19:47:28 +0800
Subject: [PATCH] [Tool]: Support converting InternLM2 to Llama format (#627)

Co-authored-by: x54-729
---
 tools/README.md        |  14 ++++
 tools/convert2llama.py | 140 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 154 insertions(+)
 create mode 100644 tools/README.md
 create mode 100644 tools/convert2llama.py

diff --git a/tools/README.md b/tools/README.md
new file mode 100644
index 0000000..37bf7af
--- /dev/null
+++ b/tools/README.md
@@ -0,0 +1,14 @@
+# InternLM2 tools
+
+## 1. Convert to LLaMA
+
+We offer `convert2llama.py`, a script that converts an InternLM2 checkpoint (HF format) into the LLaMA format (HF format). Here, HF refers to the format used by HuggingFace Transformers.
+
+### Usage
+```
+python convert2llama.py --src /path/to/internlm2/ckpt --tgt /path/to/target/ckpt
+```
+
+### Note
+
+Although `convert2llama.py` is available, we still advise using the InternLM2 format when practical, chiefly for its superior training efficiency. InternLM2, which is adapted from LLaMA, merges the `Wq`, `Wk`, and `Wv` weight matrices into a single matrix `Wqkv` (which `convert2llama.py` splits back into separate `q_proj`, `k_proj`, and `v_proj` weights). This merged projection yields approximately a **5%** speed increase during training. Given the substantial cost of pre-training, this efficiency gain can result in significant savings.
diff --git a/tools/convert2llama.py b/tools/convert2llama.py
new file mode 100644
index 0000000..7e156da
--- /dev/null
+++ b/tools/convert2llama.py
@@ -0,0 +1,140 @@
+# Copyright (c) InternLM. All rights reserved.
+import argparse
+import json
+import os
+
+import torch
+from einops import rearrange
+from tqdm import tqdm
+from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
+
+
+def save_config(config, tgt):
+    config_dict = config.to_dict()
+    unnecessary_keys = [
+        "_name_or_path",
+        "auto_map",
+        "transformers_version",
+        "model_type",
+        "architectures",
+        "tokenizer_class",
+        "attn_implementation",
+    ]
+    for k in unnecessary_keys:
+        config_dict.pop(k, None)
+    config_dict["attention_bias"] = config_dict.pop("bias")
+    config_dict["architectures"] = ["LlamaForCausalLM"]
+    llama_config = LlamaConfig(**config_dict)
+    llama_config.save_pretrained(tgt)
+
+
+def convert(src, tgt):
+    """Convert InternLM2 huggingface checkpoints to Llama-style."""
+
+    print("Convert InternLM2 huggingface checkpoints to Llama...")
+
+    config = AutoConfig.from_pretrained(src, trust_remote_code=True)
+    assert not config.bias, "Cannot convert InternLM2 model with bias to LLaMA."
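+
+    # InternLM2 packs the attention projections into a single `wqkv` weight: for
+    # each of the `num_key_value_heads` groups it stores `num_key_value_groups`
+    # query heads followed by one key head and one value head, each spanning `head_dim` rows.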
+
+    head_dim = config.hidden_size // config.num_attention_heads
+    num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+
+    # load index json file
+    index_file = os.path.join(src, "pytorch_model.bin.index.json")
+    if os.path.exists(index_file):
+        with open(index_file) as fp:
+            index_dict = json.load(fp)
+            index_dict["weight_map"] = {}
+    else:
+        index_dict = None
+
+    os.makedirs(tgt, exist_ok=True)
+    for filename in tqdm(os.listdir(src)):
+        if not filename.endswith(".bin"):
+            continue
+        states = torch.load(os.path.join(src, filename))
+        llama_states = {}
+        for k, v in states.copy().items():
+            if "wqkv" in k:
+                v = rearrange(
+                    v,
+                    "(h gs d) dim -> h gs d dim",
+                    gs=2 + num_key_value_groups,
+                    d=head_dim,
+                )
+                wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
+                wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
+                wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
+                wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
+                _prefix = k.split("attention")[0]
+                wq_key = _prefix + "self_attn.q_proj.weight"
+                wk_key = _prefix + "self_attn.k_proj.weight"
+                wv_key = _prefix + "self_attn.v_proj.weight"
+                llama_states[wq_key] = wq.clone()
+                llama_states[wk_key] = wk.clone()
+                llama_states[wv_key] = wv.clone()
+
+            elif "attention.wo" in k:
+                new_k = k.replace("attention.wo", "self_attn.o_proj")
+                llama_states[new_k] = v
+            elif "feed_forward.w1" in k:
+                new_k = k.replace("feed_forward.w1", "mlp.gate_proj")
+                llama_states[new_k] = v
+            elif "feed_forward.w2" in k:
+                new_k = k.replace("feed_forward.w2", "mlp.down_proj")
+                llama_states[new_k] = v
+            elif "feed_forward.w3" in k:
+                new_k = k.replace("feed_forward.w3", "mlp.up_proj")
+                llama_states[new_k] = v
+            elif "attention_norm" in k:
+                new_k = k.replace("attention_norm", "input_layernorm")
+                llama_states[new_k] = v
+            elif "ffn_norm" in k:
+                new_k = k.replace("ffn_norm", "post_attention_layernorm")
+                llama_states[new_k] = v
+            elif "tok_embeddings" in k:
+                llama_states["model.embed_tokens.weight"] = v
+            elif "output" in k:
+                llama_states["lm_head.weight"] = v
+            else:
+                llama_states[k] = v
+
+        if index_dict is not None:
+            for k in llama_states:
+                index_dict["weight_map"][k] = filename
+        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
+        torch.save(llama_states, os.path.join(tgt, filename))
+        del states
+
+    print("Saving config and tokenizer...")
+    # index.json
+    if index_dict is not None:
+        with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
+            json.dump(index_dict, fp, indent=2)
+    # tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(src)
+    tokenizer.init_kwargs.pop("auto_map", None)
+    tokenizer.save_pretrained(tgt)
+    # config
+    save_config(config, tgt)
+    print("Done!")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--src", type=str, help="Input folder")
+    parser.add_argument("--tgt", type=str, help="Output folder")
+
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    convert(args.src, args.tgt)
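
After converting, it can be worth checking that the Llama-format checkpoint reproduces the original model's outputs. The snippet below is a minimal verification sketch and is not part of the patch: the checkpoint paths are placeholders taken from the usage example, both models are assumed to fit in host memory, and the original InternLM2 checkpoint is loaded with `trust_remote_code=True`, just as `convert2llama.py` loads its config. The logits should agree up to floating-point noise.

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

src = "/path/to/internlm2/ckpt"  # original InternLM2 checkpoint (HF format)
tgt = "/path/to/target/ckpt"     # output of convert2llama.py

# The original checkpoint needs the repo's remote modeling code; the converted
# one loads as a plain Llama model.
ref = AutoModelForCausalLM.from_pretrained(src, trust_remote_code=True).eval()
new = AutoModelForCausalLM.from_pretrained(tgt).eval()

tokenizer = AutoTokenizer.from_pretrained(src, trust_remote_code=True)
inputs = tokenizer("Hello, InternLM2!", return_tensors="pt")

with torch.no_grad():
    ref_logits = ref(**inputs).logits
    new_logits = new(**inputs).logits

# The converted model should reproduce the original logits up to numerical precision.
print("max abs diff:", (ref_logits - new_logits).abs().max().item())
```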