remove --tp_size; rearrange wqkv

pull/627/head
x54-729 2024-01-19 17:29:40 +08:00
parent d47962c2d0
commit 3d609d8e38
1 changed file with 80 additions and 75 deletions


@@ -1,110 +1,116 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
 import os
-from collections import defaultdict
+import json
 
 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
 
 
+def save_conifg(config, tgt):
+    config_dict = config.to_dict()
+    unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
+    for k in unnecessary_keys:
+        config_dict.pop(k, None)
+    config_dict["attention_bias"] = config_dict.pop("bias")
+    config_dict["architectures"] = ["LlamaForCausalLM"]
+    llama_config = LlamaConfig(**config_dict)
+    llama_config.save_pretrained(tgt)
+
+
-def split_wqkv(qkv, num_groups, q_per_kv, head_dim):
-    """Split wqkv into wq, wk, wv."""
-    qkv = qkv.T
-    qkv = rearrange(qkv, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-    q = qkv[..., :q_per_kv, :]
-    k = qkv[..., q_per_kv : q_per_kv + 1, :]
-    v = qkv[..., q_per_kv + 1 : q_per_kv + 2, :]
-    q = rearrange(q, "o g n i -> o (g n i)", g=num_groups, n=q_per_kv, i=head_dim)
-    k = rearrange(k, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    v = rearrange(v, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    return q.T, k.T, v.T
-
-
-def convert(src, tgt, tp_size):
+def convert(src, tgt):
     """Convert InternLM2 huggingface checkpoints to Llama-style."""
-    print("Loading origin checkpoints...")
-    hf_states = []
-    hf_state_names = []
-    remain_files = []
-    for filename in tqdm(os.listdir(src)):
-        if not filename.endswith(".bin"):
-            remain_files.append(filename)
-            continue
-        hf_state_names.append(filename)
-        hf_states.append(torch.load(os.path.join(src, filename)))
-
     print("Convert InternLM2 huggingface checkpoints to Llama...")
     config = AutoConfig.from_pretrained(src, trust_remote_code=True)
     assert not config.bias, "Cannot convert InternLM Model with bias to LLaMA."
-    q_per_kv = config.num_attention_heads // config.num_key_value_heads
     head_dim = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads
-    num_heads_per_tp = num_heads // tp_size
-    num_groups = num_heads_per_tp // q_per_kv
+    num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
-    for states in tqdm(hf_states):
-        tmp_states = defaultdict(defaultdict)
+    # load index json file
+    index_file = os.path.join(src, "pytorch_model.bin.index.json")
+    if os.path.exists(index_file):
+        with open(index_file) as fp:
+            index_dict = json.load(fp)
+            index_dict["weight_map"] = {}
+    else:
+        index_dict = None
+
+    os.makedirs(tgt, exist_ok=True)
+    for filename in tqdm(os.listdir(src)):
+        if not filename.endswith(".bin"):
+            continue
+        states = torch.load(os.path.join(src, filename))
+        llama_states = {}
         for k, v in states.copy().items():
             if "wqkv" in k:
-                wqkvs = v.chunk(tp_size, 0)
-                for i in range(tp_size):
-                    wq, wk, wv = split_wqkv(wqkvs[i], num_groups, q_per_kv, head_dim)
-                    _prefix = k.split("attention")[0]
-                    wq_key = _prefix + "self_attn.q_proj.weight"
-                    wk_key = _prefix + "self_attn.k_proj.weight"
-                    wv_key = _prefix + "self_attn.v_proj.weight"
-                    tmp_states[wq_key][i] = wq.clone()
-                    tmp_states[wk_key][i] = wk.clone()
-                    tmp_states[wv_key][i] = wv.clone()
+                v = rearrange(
+                    v,
+                    "(h gs d) dim -> h gs d dim",
+                    gs=2 + num_key_value_groups,
+                    d=head_dim,
+                )
+                wq, wk, wv = torch.split(
+                    v, [num_key_value_groups, 1, 1], dim=1
+                )
+                wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
+                wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
+                wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
+                _prefix = k.split("attention")[0]
+                wq_key = _prefix + "self_attn.q_proj.weight"
+                wk_key = _prefix + "self_attn.k_proj.weight"
+                wv_key = _prefix + "self_attn.v_proj.weight"
+                llama_states[wq_key] = wq.clone()
+                llama_states[wk_key] = wk.clone()
+                llama_states[wv_key] = wv.clone()
             elif "attention.wo" in k:
                 new_k = k.replace("attention.wo", "self_attn.o_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w1" in k:
                 new_k = k.replace("feed_forward.w1", "mlp.gate_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w2" in k:
-                new_k = k.replace("feed_forward.w2", "mlp.up_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w2", "mlp.down_proj")
+                llama_states[new_k] = v
             elif "feed_forward.w3" in k:
-                new_k = k.replace("feed_forward.w3", "mlp.down_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w3", "mlp.up_proj")
+                llama_states[new_k] = v
            elif "attention_norm" in k:
                 new_k = k.replace("attention_norm", "input_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "ffn_norm" in k:
                 new_k = k.replace("ffn_norm", "post_attention_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "tok_embeddings" in k:
-                states["model.embed_tokens.weight"] = v
-                del states[k]
+                llama_states["model.embed_tokens.weight"] = v
             elif "output" in k:
-                states["lm_head.weight"] = v
-                del states[k]
+                llama_states["lm_head.weight"] = v
+            else:
+                llama_states[k] = v
 
-        for k, v in tmp_states.items():
-            states[k] = torch.cat(list(v.values()), dim=0)
+        if index_dict is not None:
+            for k in llama_states.keys():
+                index_dict["weight_map"][k] = filename
+        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
+        torch.save(llama_states, os.path.join(tgt, filename))
+        del states
 
-    os.makedirs(tgt, exist_ok=True)
-    for i, states in enumerate(hf_states):
-        print(f"Saving to {os.path.join(tgt, hf_state_names[i])}...", flush=True)
-        torch.save(states, os.path.join(tgt, hf_state_names[i]))
-    for filename in remain_files:
-        print(f"Copying {filename}...", flush=True)
-        os.system(f"cp {os.path.join(src, filename)} {tgt}")
+    print("Saving config and tokenizer...")
+    # index.json
+    if index_dict is not None:
+        with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
+            json.dump(index_dict, fp, indent=2)
+    # tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(src)
+    tokenizer.init_kwargs.pop("auto_map", None)
+    tokenizer.save_pretrained(tgt)
+    # config
+    save_conifg(config, tgt)
     print("Done!")
@@ -112,7 +118,6 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--src", type=str, help="Input folder")
     parser.add_argument("--tgt", type=str, help="Output folder")
-    parser.add_argument("--tp_size", type=int, help="world_size of tensor parallel")
     args = parser.parse_args()
@@ -122,4 +127,4 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-    convert(args.src, args.tgt, args.tp_size)
+    convert(args.src, args.tgt)
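
With --tp_size gone, the script takes only --src and --tgt, and because the commit also writes a Llama-style config and tokenizer into the target folder, the output should load with the stock Llama classes. A usage sketch follows; the script name and paths are placeholders, not taken from this commit.

# Placeholder invocation (script and path names assumed):
#   python convert2llama.py --src ./internlm2-hf --tgt ./internlm2-llama
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("./internlm2-llama")
model = LlamaForCausalLM.from_pretrained("./internlm2-llama")
print(type(model).__name__)  # LlamaForCausalLM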