diff --git a/tools/convert2llama.py b/tools/convert2llama.py
index 1b6c25d..7e156da 100644
--- a/tools/convert2llama.py
+++ b/tools/convert2llama.py
@@ -1,16 +1,25 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
-import os
 import json
+import os
 
 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
+from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
 
+
 def save_conifg(config, tgt):
     config_dict = config.to_dict()
-    unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
+    unnecessary_keys = [
+        "_name_or_path",
+        "auto_map",
+        "transformers_version",
+        "model_type",
+        "architectures",
+        "tokenizer_class",
+        "attn_implementation",
+    ]
     for k in unnecessary_keys:
         config_dict.pop(k, None)
     config_dict["attention_bias"] = config_dict.pop("bias")
@@ -29,7 +38,6 @@ def convert(src, tgt):
     head_dim = config.hidden_size // config.num_attention_heads
     num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
-
     # load index json file
     index_file = os.path.join(src, "pytorch_model.bin.index.json")
 
@@ -54,9 +62,7 @@ def convert(src, tgt):
                     gs=2 + num_key_value_groups,
                     d=head_dim,
                 )
-                wq, wk, wv = torch.split(
-                    v, [num_key_value_groups, 1, 1], dim=1
-                )
+                wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
                 wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
                 wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
                 wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
@@ -94,7 +100,7 @@ def convert(src, tgt):
                 llama_states[k] = v
 
         if index_dict is not None:
-            for k in llama_states.keys():
+            for k in llama_states:
                 index_dict["weight_map"][k] = filename
         print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
         torch.save(llama_states, os.path.join(tgt, filename))
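
For reference only (not part of the patch): a minimal, self-contained sketch of the `wqkv` split that the third hunk reformats. The head counts below are made up purely for illustration; the point is that `gs=2 + num_key_value_groups` together with the `[num_key_value_groups, 1, 1]` split recovers q/k/v projections with the shapes a Llama-style checkpoint expects.

```python
import torch
from einops import rearrange

# Illustrative sizes only -- not taken from any real InternLM2 config.
num_attention_heads = 8
num_key_value_heads = 2
head_dim = 16
hidden_size = num_attention_heads * head_dim
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4

# Fused wqkv weight: per the rearrange pattern in convert(), each kv head is
# stored together with its group of q heads plus one k and one v head.
wqkv = torch.randn((2 + num_key_value_groups) * num_key_value_heads * head_dim, hidden_size)

# Same rearrange/split pattern as in convert().
v = rearrange(wqkv, "(h gs d) dim -> h gs d dim", gs=2 + num_key_value_groups, d=head_dim)
wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
wv = rearrange(wv, "h gs d dim -> (h gs d) dim")

# Shapes a Llama checkpoint expects for q_proj / k_proj / v_proj weights.
assert wq.shape == (num_attention_heads * head_dim, hidden_size)
assert wk.shape == (num_key_value_heads * head_dim, hidden_size)
assert wv.shape == (num_key_value_heads * head_dim, hidden_size)
```

Running the snippet to completion (no assertion errors) is a quick sanity check that the split sizes match the fused layout implied by `convert()`.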