# Copyright (c) InternLM. All rights reserved.
import argparse
import json
import os

import torch
from einops import rearrange
from tqdm import tqdm
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer


def save_config(config, tgt):
    """Save the InternLM2 config as a LlamaConfig in the target folder."""
    config_dict = config.to_dict()
    unnecessary_keys = [
        "_name_or_path",
        "auto_map",
        "transformers_version",
        "model_type",
        "architectures",
        "tokenizer_class",
        "attn_implementation",
    ]
    for k in unnecessary_keys:
        config_dict.pop(k, None)
    # LlamaConfig expects `attention_bias` instead of InternLM2's `bias`.
    config_dict["attention_bias"] = config_dict.pop("bias")
    config_dict["architectures"] = ["LlamaForCausalLM"]
    llama_config = LlamaConfig(**config_dict)
    llama_config.save_pretrained(tgt)


def convert(src, tgt):
    """Convert InternLM2 huggingface checkpoints to Llama-style."""
    print("Convert InternLM2 huggingface checkpoints to Llama...")
    config = AutoConfig.from_pretrained(src, trust_remote_code=True)
    assert not config.bias, "Cannot convert InternLM Model with bias to LLaMA."

    head_dim = config.hidden_size // config.num_attention_heads
    num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

    # load index json file
    index_file = os.path.join(src, "pytorch_model.bin.index.json")
    if os.path.exists(index_file):
        with open(index_file) as fp:
            index_dict = json.load(fp)
            index_dict["weight_map"] = {}
    else:
        index_dict = None

    os.makedirs(tgt, exist_ok=True)
    for filename in tqdm(os.listdir(src)):
        if not filename.endswith(".bin"):
            continue
        states = torch.load(os.path.join(src, filename))
        llama_states = {}
        for k, v in states.copy().items():
            if "wqkv" in k:
                # Split the fused wqkv weight into separate q/k/v projections:
                # each kv head carries num_key_value_groups query heads plus one k and one v.
                v = rearrange(
                    v,
                    "(h gs d) dim -> h gs d dim",
                    gs=2 + num_key_value_groups,
                    d=head_dim,
                )
                wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
                wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
                wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
                wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
                _prefix = k.split("attention")[0]
                wq_key = _prefix + "self_attn.q_proj.weight"
                wk_key = _prefix + "self_attn.k_proj.weight"
                wv_key = _prefix + "self_attn.v_proj.weight"
                llama_states[wq_key] = wq.clone()
                llama_states[wk_key] = wk.clone()
                llama_states[wv_key] = wv.clone()
            # Rename the remaining InternLM2 parameters to their Llama equivalents.
            elif "attention.wo" in k:
                new_k = k.replace("attention.wo", "self_attn.o_proj")
                llama_states[new_k] = v
            elif "feed_forward.w1" in k:
                new_k = k.replace("feed_forward.w1", "mlp.gate_proj")
                llama_states[new_k] = v
            elif "feed_forward.w2" in k:
                new_k = k.replace("feed_forward.w2", "mlp.down_proj")
                llama_states[new_k] = v
            elif "feed_forward.w3" in k:
                new_k = k.replace("feed_forward.w3", "mlp.up_proj")
                llama_states[new_k] = v
            elif "attention_norm" in k:
                new_k = k.replace("attention_norm", "input_layernorm")
                llama_states[new_k] = v
            elif "ffn_norm" in k:
                new_k = k.replace("ffn_norm", "post_attention_layernorm")
                llama_states[new_k] = v
            elif "tok_embeddings" in k:
                llama_states["model.embed_tokens.weight"] = v
            elif "output" in k:
                llama_states["lm_head.weight"] = v
            else:
                llama_states[k] = v

        if index_dict is not None:
            for k in llama_states:
                index_dict["weight_map"][k] = filename

        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
        torch.save(llama_states, os.path.join(tgt, filename))
        del states

    print("Saving config and tokenizer...")
    # index.json
    if index_dict is not None:
        with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
            json.dump(index_dict, fp, indent=2)
    # tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(src)
    tokenizer.init_kwargs.pop("auto_map", None)
    tokenizer.save_pretrained(tgt)
    # config
    save_config(config, tgt)

    print("Done!")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, help="Input folder")
    parser.add_argument("--tgt", type=str, help="Output folder")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    convert(args.src, args.tgt)
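
# Example invocation. The script name and paths below are hypothetical,
# shown only for illustration of the --src/--tgt flags defined above:
#
#   python convert2llama.py --src /path/to/internlm2-hf --tgt /path/to/llama-format
#
# --src should be an InternLM2 HuggingFace checkpoint directory (the *.bin shards,
# config.json and tokenizer files are read from it); --tgt is created if missing.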