mirror of https://github.com/InternLM/InternLM
remove --tp_size; rearrange wqkv
parent d47962c2d0
commit 3d609d8e38
@@ -1,110 +1,116 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
 import os
-from collections import defaultdict
+import json
 
 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
 
 
-def split_wqkv(qkv, num_groups, q_per_kv, head_dim):
-    """Split wqkv into wq, wk, wv."""
-    qkv = qkv.T
-    qkv = rearrange(qkv, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-
-    q = qkv[..., :q_per_kv, :]
-    k = qkv[..., q_per_kv : q_per_kv + 1, :]
-    v = qkv[..., q_per_kv + 1 : q_per_kv + 2, :]
-
-    q = rearrange(q, "o g n i -> o (g n i)", g=num_groups, n=q_per_kv, i=head_dim)
-    k = rearrange(k, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    v = rearrange(v, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    return q.T, k.T, v.T
-
-
-def convert(src, tgt, tp_size):
+def save_conifg(config, tgt):
+    config_dict = config.to_dict()
+    unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
+    for k in unnecessary_keys:
+        config_dict.pop(k, None)
+    config_dict["attention_bias"] = config_dict.pop("bias")
+    config_dict["architectures"] = ["LlamaForCausalLM"]
+    llama_config = LlamaConfig(**config_dict)
+    llama_config.save_pretrained(tgt)
+
+
+def convert(src, tgt):
     """Convert InternLM2 huggingface checkpoints to Llama-style."""
-    print("Loading origin checkpoints...")
-    hf_states = []
-    hf_state_names = []
-    remain_files = []
-    for filename in tqdm(os.listdir(src)):
-        if not filename.endswith(".bin"):
-            remain_files.append(filename)
-            continue
-        hf_state_names.append(filename)
-        hf_states.append(torch.load(os.path.join(src, filename)))
 
     print("Convert InternLM2 huggingface checkpoints to Llama...")
 
     config = AutoConfig.from_pretrained(src, trust_remote_code=True)
+    assert not config.bias, "Cannot convert InternLM Model with bias to LLaMA."
 
-    q_per_kv = config.num_attention_heads // config.num_key_value_heads
     head_dim = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads
-    num_heads_per_tp = num_heads // tp_size
-    num_groups = num_heads_per_tp // q_per_kv
+    num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
-    for states in tqdm(hf_states):
-        tmp_states = defaultdict(defaultdict)
+    # load index json file
+    index_file = os.path.join(src, "pytorch_model.bin.index.json")
+    if os.path.exists(index_file):
+        with open(index_file) as fp:
+            index_dict = json.load(fp)
+            index_dict["weight_map"] = {}
+    else:
+        index_dict = None
 
+    os.makedirs(tgt, exist_ok=True)
+    for filename in tqdm(os.listdir(src)):
+        if not filename.endswith(".bin"):
+            continue
+        states = torch.load(os.path.join(src, filename))
+        llama_states = {}
         for k, v in states.copy().items():
             if "wqkv" in k:
-                wqkvs = v.chunk(tp_size, 0)
-                for i in range(tp_size):
-                    wq, wk, wv = split_wqkv(wqkvs[i], num_groups, q_per_kv, head_dim)
-                    _prefix = k.split("attention")[0]
-                    wq_key = _prefix + "self_attn.q_proj.weight"
-                    wk_key = _prefix + "self_attn.k_proj.weight"
-                    wv_key = _prefix + "self_attn.v_proj.weight"
-                    tmp_states[wq_key][i] = wq.clone()
-                    tmp_states[wk_key][i] = wk.clone()
-                    tmp_states[wv_key][i] = wv.clone()
+                v = rearrange(
+                    v,
+                    "(h gs d) dim -> h gs d dim",
+                    gs=2 + num_key_value_groups,
+                    d=head_dim,
+                )
+                wq, wk, wv = torch.split(
+                    v, [num_key_value_groups, 1, 1], dim=1
+                )
+                wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
+                wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
+                wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
+                _prefix = k.split("attention")[0]
+                wq_key = _prefix + "self_attn.q_proj.weight"
+                wk_key = _prefix + "self_attn.k_proj.weight"
+                wv_key = _prefix + "self_attn.v_proj.weight"
+                llama_states[wq_key] = wq.clone()
+                llama_states[wk_key] = wk.clone()
+                llama_states[wv_key] = wv.clone()
             elif "attention.wo" in k:
                 new_k = k.replace("attention.wo", "self_attn.o_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w1" in k:
                 new_k = k.replace("feed_forward.w1", "mlp.gate_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w2" in k:
-                new_k = k.replace("feed_forward.w2", "mlp.up_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w2", "mlp.down_proj")
+                llama_states[new_k] = v
             elif "feed_forward.w3" in k:
-                new_k = k.replace("feed_forward.w3", "mlp.down_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w3", "mlp.up_proj")
+                llama_states[new_k] = v
            elif "attention_norm" in k:
                 new_k = k.replace("attention_norm", "input_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "ffn_norm" in k:
                 new_k = k.replace("ffn_norm", "post_attention_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "tok_embeddings" in k:
-                states["model.embed_tokens.weight"] = v
-                del states[k]
+                llama_states["model.embed_tokens.weight"] = v
             elif "output" in k:
-                states["lm_head.weight"] = v
-                del states[k]
+                llama_states["lm_head.weight"] = v
+            else:
+                llama_states[k] = v
 
-        for k, v in tmp_states.items():
-            states[k] = torch.cat(list(v.values()), dim=0)
+        if index_dict is not None:
+            for k in llama_states.keys():
+                index_dict["weight_map"][k] = filename
+        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
+        torch.save(llama_states, os.path.join(tgt, filename))
+        del states
 
-    os.makedirs(tgt, exist_ok=True)
-    for i, states in enumerate(hf_states):
-        print(f"Saving to {os.path.join(tgt, hf_state_names[i])}...", flush=True)
-        torch.save(states, os.path.join(tgt, hf_state_names[i]))
-    for filename in remain_files:
-        print(f"Copying {filename}...", flush=True)
-        os.system(f"cp {os.path.join(src, filename)} {tgt}")
+    print("Saving config and tokenizer...")
+    # index.json
+    if index_dict is not None:
+        with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
+            json.dump(index_dict, fp, indent=2)
+    # tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(src)
+    tokenizer.init_kwargs.pop("auto_map", None)
+    tokenizer.save_pretrained(tgt)
+    # config
+    save_conifg(config, tgt)
     print("Done!")
|
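The hunk above replaces the tensor-parallel split_wqkv helper with a single rearrange over the packed wqkv weight: rows are regrouped as (kv_head, q_per_kv + 2, head_dim) and q, k, v are split off along the group axis. Below is a minimal sanity check of that layout, assuming toy sizes (8 query heads, 2 key/value heads, head_dim 4, hidden size 32) rather than a real InternLM2 checkpoint.

import torch
from einops import rearrange

num_heads, num_kv_heads, head_dim, hidden = 8, 2, 4, 32  # toy sizes, not from a config
q_per_kv = num_heads // num_kv_heads  # plays the role of num_key_value_groups in the diff

# Packed InternLM2 layout: for each kv group, q_per_kv query heads, then one k head, then one v head.
wqkv = torch.arange(num_kv_heads * (q_per_kv + 2) * head_dim * hidden, dtype=torch.float32)
wqkv = wqkv.reshape(num_kv_heads * (q_per_kv + 2) * head_dim, hidden)

grouped = rearrange(wqkv, "(h gs d) dim -> h gs d dim", gs=2 + q_per_kv, d=head_dim)
wq, wk, wv = torch.split(grouped, [q_per_kv, 1, 1], dim=1)
wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
wv = rearrange(wv, "h gs d dim -> (h gs d) dim")

assert wq.shape == (num_heads * head_dim, hidden)
assert wk.shape == (num_kv_heads * head_dim, hidden)
assert wv.shape == (num_kv_heads * head_dim, hidden)
# First q head of group 0 is the first head_dim rows of the packed weight,
# and group 0's k head sits right after that group's q heads.
assert torch.equal(wq[:head_dim], wqkv[:head_dim])
assert torch.equal(wk[:head_dim], wqkv[q_per_kv * head_dim:(q_per_kv + 1) * head_dim])
print("wqkv split layout checks out")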
@@ -112,7 +118,6 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--src", type=str, help="Input folder")
     parser.add_argument("--tgt", type=str, help="Output folder")
-    parser.add_argument("--tp_size", type=int, help="world_size of tensor parallel")
 
     args = parser.parse_args()

@@ -122,4 +127,4 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
 
-    convert(args.src, args.tgt, args.tp_size)
+    convert(args.src, args.tgt)
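The last two hunks drop the --tp_size flag, so the converter now only needs the source and target folders. A quick way to check a converted folder is to load it with the stock Llama classes; this is only an illustrative sketch, and the path below is a placeholder rather than anything taken from this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tgt = "/path/to/converted-llama-checkpoint"  # placeholder: the --tgt folder written by convert()

# The rewritten config advertises LlamaForCausalLM, so trust_remote_code is not needed here.
tokenizer = AutoTokenizer.from_pretrained(tgt)
model = AutoModelForCausalLM.from_pretrained(tgt, torch_dtype=torch.float16)

inputs = tokenizer("Hello", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0]))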