2024-01-19 11:47:28 +00:00
|
|
|
# Copyright (c) InternLM. All rights reserved.
|
|
|
|
import argparse
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from einops import rearrange
|
|
|
|
from tqdm import tqdm
|
|
|
|
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
def save_config(config, tgt):
    """Re-serialize an InternLM2 config as a ``LlamaConfig`` under *tgt*.

    Drops InternLM-specific / bookkeeping keys, renames ``bias`` to
    Llama's ``attention_bias`` and forces the ``LlamaForCausalLM``
    architecture before saving.

    Args:
        config: the source ``PretrainedConfig`` loaded from the InternLM2
            checkpoint.
        tgt: output folder to write ``config.json`` into.
    """
    config_dict = config.to_dict()
    unnecessary_keys = [
        '_name_or_path',
        'auto_map',
        'transformers_version',
        'model_type',
        'architectures',
        'tokenizer_class',
        'attn_implementation',
    ]
    for k in unnecessary_keys:
        config_dict.pop(k, None)
    # InternLM2 names the attention projection bias flag ``bias``; Llama
    # calls it ``attention_bias``.  Default to False when the key is
    # absent instead of raising KeyError.
    config_dict['attention_bias'] = config_dict.pop('bias', False)
    config_dict['architectures'] = ['LlamaForCausalLM']
    llama_config = LlamaConfig(**config_dict)
    llama_config.save_pretrained(tgt)


# Backward-compatible alias for the original (misspelled) public name.
save_conifg = save_config
|
|
|
|
|
|
|
|
|
|
|
|
def convert(src, tgt):
    """Convert InternLM2 huggingface checkpoints to Llama-style.

    Reads every ``*.bin`` shard in *src*, renames/reshapes the weights to
    the Llama layout, and writes the converted shards plus index, config
    and tokenizer files into *tgt*.

    Args:
        src: folder holding the InternLM2 HF checkpoint.
        tgt: output folder; created if it does not exist.
    """
    print('Convert InternLM2 huggingface checkpoints to Llama...')

    config = AutoConfig.from_pretrained(src, trust_remote_code=True)
    assert not config.bias, 'Cannot convert InternLM Model with bias to LLaMA.'

    head_dim = config.hidden_size // config.num_attention_heads
    # number of query heads sharing one key/value head (GQA group size)
    num_key_value_groups = config.num_attention_heads \
        // config.num_key_value_heads

    # load index json file (absent for single-shard checkpoints)
    index_file = os.path.join(src, 'pytorch_model.bin.index.json')
    if os.path.exists(index_file):
        with open(index_file) as fp:
            index_dict = json.load(fp)
        # the weight map is rebuilt below with the renamed keys
        index_dict['weight_map'] = {}
    else:
        index_dict = None

    os.makedirs(tgt, exist_ok=True)
    for filename in tqdm(os.listdir(src)):
        if not filename.endswith('.bin'):
            continue
        # map_location='cpu' so conversion also works on machines without
        # GPUs and never allocates device memory for the shards.
        states = torch.load(os.path.join(src, filename),
                            map_location='cpu')
        llama_states = {}
        for k, v in states.copy().items():
            if 'wqkv' in k:
                # InternLM2 stores q/k/v fused per kv-head group:
                # each group holds num_key_value_groups query slices
                # followed by one key and one value slice.
                v = rearrange(
                    v,
                    '(h gs d) dim -> h gs d dim',
                    gs=2 + num_key_value_groups,
                    d=head_dim,
                )
                wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1],
                                         dim=1)
                wq = rearrange(wq, 'h gs d dim -> (h gs d) dim')
                wk = rearrange(wk, 'h gs d dim -> (h gs d) dim')
                wv = rearrange(wv, 'h gs d dim -> (h gs d) dim')
                _prefix = k.split('attention')[0]
                wq_key = _prefix + 'self_attn.q_proj.weight'
                wk_key = _prefix + 'self_attn.k_proj.weight'
                wv_key = _prefix + 'self_attn.v_proj.weight'
                # clone() so the saved tensors do not keep the whole
                # fused tensor alive via shared storage
                llama_states[wq_key] = wq.clone()
                llama_states[wk_key] = wk.clone()
                llama_states[wv_key] = wv.clone()

            elif 'attention.wo' in k:
                new_k = k.replace('attention.wo', 'self_attn.o_proj')
                llama_states[new_k] = v
            elif 'feed_forward.w1' in k:
                new_k = k.replace('feed_forward.w1', 'mlp.gate_proj')
                llama_states[new_k] = v
            elif 'feed_forward.w2' in k:
                new_k = k.replace('feed_forward.w2', 'mlp.down_proj')
                llama_states[new_k] = v
            elif 'feed_forward.w3' in k:
                new_k = k.replace('feed_forward.w3', 'mlp.up_proj')
                llama_states[new_k] = v
            elif 'attention_norm' in k:
                new_k = k.replace('attention_norm', 'input_layernorm')
                llama_states[new_k] = v
            elif 'ffn_norm' in k:
                new_k = k.replace('ffn_norm', 'post_attention_layernorm')
                llama_states[new_k] = v
            elif 'tok_embeddings' in k:
                llama_states['model.embed_tokens.weight'] = v
            elif 'output' in k:
                llama_states['lm_head.weight'] = v
            else:
                # e.g. model.norm.weight — already Llama-compatible
                llama_states[k] = v

        if index_dict is not None:
            for k in llama_states:
                index_dict['weight_map'][k] = filename
        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
        torch.save(llama_states, os.path.join(tgt, filename))
        # free the shard before loading the next one to keep peak RSS low
        del states

    print('Saving config and tokenizer...')
    # index.json
    if index_dict is not None:
        with open(os.path.join(tgt, 'pytorch_model.bin.index.json'),
                  'w') as fp:
            json.dump(index_dict, fp, indent=2)
    # tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(src)
    # drop the remote-code mapping so the saved tokenizer loads as a
    # plain LlamaTokenizer
    tokenizer.init_kwargs.pop('auto_map', None)
    tokenizer.save_pretrained(tgt)
    # config
    save_conifg(config, tgt)
    print('Done!')
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line arguments for the conversion script.

    Returns:
        argparse.Namespace with ``src`` (input folder) and ``tgt``
        (output folder).
    """
    parser = argparse.ArgumentParser()
    # required=True: failing fast here is clearer than a confusing
    # downstream error when src/tgt are None.
    parser.add_argument('--src', type=str, required=True,
                        help='Input folder')
    parser.add_argument('--tgt', type=str, required=True,
                        help='Output folder')

    args = parser.parse_args()

    return args
|
|
|
|
|
|
|
|
|
2024-01-26 09:26:04 +00:00
|
|
|
if __name__ == '__main__':
    # Entry point: read CLI options, then run the conversion.
    cli_args = parse_args()
    convert(cli_args.src, cli_args.tgt)
|