remove --tp_size; rearrange wqkv

Branch: pull/627/head
Author: x54-729
Date: 2024-01-19 17:29:40 +08:00
Parent: d47962c2d0
Commit: 3d609d8e38
1 changed file with 80 additions and 75 deletions
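In short: the commit drops the tensor-parallel splitting path (and its --tp_size flag), de-fuses the packed wqkv weight with a single grouped rearrange, swaps the feed_forward.w2/w3 mapping to mlp.down_proj/mlp.up_proj, and additionally writes out a Llama-style config, the tokenizer, and a pytorch_model.bin.index.json weight map. As a rough, standalone sketch of the new wqkv handling (not part of the commit; the model sizes below are assumed toy values):

# Toy illustration of the new wqkv de-fusing; sizes are assumed, not from the commit.
import torch
from einops import rearrange

hidden_size = 4096                # assumed
num_attention_heads = 32          # assumed
num_key_value_heads = 8           # assumed (grouped-query attention)
head_dim = hidden_size // num_attention_heads                      # 128
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4 query heads per kv head

# InternLM2 stores q, k and v fused: per kv head, num_key_value_groups query
# slices followed by one key slice and one value slice, each of size head_dim.
wqkv = torch.randn(num_key_value_heads * (num_key_value_groups + 2) * head_dim, hidden_size)

grouped = rearrange(wqkv, "(h gs d) dim -> h gs d dim", gs=2 + num_key_value_groups, d=head_dim)
wq, wk, wv = torch.split(grouped, [num_key_value_groups, 1, 1], dim=1)
wq = rearrange(wq, "h gs d dim -> (h gs d) dim")   # (4096, 4096): 32 query heads
wk = rearrange(wk, "h gs d dim -> (h gs d) dim")   # (1024, 4096): 8 key heads
wv = rearrange(wv, "h gs d dim -> (h gs d) dim")   # (1024, 4096): 8 value heads
print(wq.shape, wk.shape, wv.shape)

Each key/value head carries num_key_value_groups query slices plus one key and one value slice, which is why the group axis has size 2 + num_key_value_groups.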

@@ -1,110 +1,116 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
 import os
-from collections import defaultdict
+import json
 
 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
 
 
-def split_wqkv(qkv, num_groups, q_per_kv, head_dim):
-    """Split wqkv into wq, wk, wv."""
-    qkv = qkv.T
-    qkv = rearrange(qkv, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-    q = qkv[..., :q_per_kv, :]
-    k = qkv[..., q_per_kv : q_per_kv + 1, :]
-    v = qkv[..., q_per_kv + 1 : q_per_kv + 2, :]
-    q = rearrange(q, "o g n i -> o (g n i)", g=num_groups, n=q_per_kv, i=head_dim)
-    k = rearrange(k, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    v = rearrange(v, "o g n i -> o (g n i)", g=num_groups, n=1, i=head_dim)
-    return q.T, k.T, v.T
+def save_conifg(config, tgt):
+    config_dict = config.to_dict()
+    unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
+    for k in unnecessary_keys:
+        config_dict.pop(k, None)
+    config_dict["attention_bias"] = config_dict.pop("bias")
+    config_dict["architectures"] = ["LlamaForCausalLM"]
+    llama_config = LlamaConfig(**config_dict)
+    llama_config.save_pretrained(tgt)
 
 
-def convert(src, tgt, tp_size):
+def convert(src, tgt):
     """Convert InternLM2 huggingface checkpoints to Llama-style."""
-    print("Loading origin checkpoints...")
-    hf_states = []
-    hf_state_names = []
-    remain_files = []
-    for filename in tqdm(os.listdir(src)):
-        if not filename.endswith(".bin"):
-            remain_files.append(filename)
-            continue
-        hf_state_names.append(filename)
-        hf_states.append(torch.load(os.path.join(src, filename)))
-
     print("Convert InternLM2 huggingface checkpoints to Llama...")
     config = AutoConfig.from_pretrained(src, trust_remote_code=True)
+    assert not config.bias, "Cannot convert InternLM Model with bias to LLaMA."
 
-    q_per_kv = config.num_attention_heads // config.num_key_value_heads
     head_dim = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads
-    num_heads_per_tp = num_heads // tp_size
-    num_groups = num_heads_per_tp // q_per_kv
+    num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
-    for states in tqdm(hf_states):
-        tmp_states = defaultdict(defaultdict)
-
+    # load index json file
+    index_file = os.path.join(src, "pytorch_model.bin.index.json")
+    if os.path.exists(index_file):
+        with open(index_file) as fp:
+            index_dict = json.load(fp)
+        index_dict["weight_map"] = {}
+    else:
+        index_dict = None
+
+    os.makedirs(tgt, exist_ok=True)
+    for filename in tqdm(os.listdir(src)):
+        if not filename.endswith(".bin"):
+            continue
+        states = torch.load(os.path.join(src, filename))
+        llama_states = {}
         for k, v in states.copy().items():
             if "wqkv" in k:
-                wqkvs = v.chunk(tp_size, 0)
-                for i in range(tp_size):
-                    wq, wk, wv = split_wqkv(wqkvs[i], num_groups, q_per_kv, head_dim)
-                    _prefix = k.split("attention")[0]
-                    wq_key = _prefix + "self_attn.q_proj.weight"
-                    wk_key = _prefix + "self_attn.k_proj.weight"
-                    wv_key = _prefix + "self_attn.v_proj.weight"
-                    tmp_states[wq_key][i] = wq.clone()
-                    tmp_states[wk_key][i] = wk.clone()
-                    tmp_states[wv_key][i] = wv.clone()
+                v = rearrange(
+                    v,
+                    "(h gs d) dim -> h gs d dim",
+                    gs=2 + num_key_value_groups,
+                    d=head_dim,
+                )
+                wq, wk, wv = torch.split(
+                    v, [num_key_value_groups, 1, 1], dim=1
+                )
+                wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
+                wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
+                wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
+                _prefix = k.split("attention")[0]
+                wq_key = _prefix + "self_attn.q_proj.weight"
+                wk_key = _prefix + "self_attn.k_proj.weight"
+                wv_key = _prefix + "self_attn.v_proj.weight"
+                llama_states[wq_key] = wq.clone()
+                llama_states[wk_key] = wk.clone()
+                llama_states[wv_key] = wv.clone()
             elif "attention.wo" in k:
                 new_k = k.replace("attention.wo", "self_attn.o_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w1" in k:
                 new_k = k.replace("feed_forward.w1", "mlp.gate_proj")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
             elif "feed_forward.w2" in k:
-                new_k = k.replace("feed_forward.w2", "mlp.up_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w2", "mlp.down_proj")
+                llama_states[new_k] = v
            elif "feed_forward.w3" in k:
-                new_k = k.replace("feed_forward.w3", "mlp.down_proj")
-                states[new_k] = v
-                del states[k]
+                new_k = k.replace("feed_forward.w3", "mlp.up_proj")
+                llama_states[new_k] = v
            elif "attention_norm" in k:
                 new_k = k.replace("attention_norm", "input_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
            elif "ffn_norm" in k:
                 new_k = k.replace("ffn_norm", "post_attention_layernorm")
-                states[new_k] = v
-                del states[k]
+                llama_states[new_k] = v
            elif "tok_embeddings" in k:
-                states["model.embed_tokens.weight"] = v
-                del states[k]
+                llama_states["model.embed_tokens.weight"] = v
            elif "output" in k:
-                states["lm_head.weight"] = v
-                del states[k]
+                llama_states["lm_head.weight"] = v
+            else:
+                llama_states[k] = v
 
-        for k, v in tmp_states.items():
-            states[k] = torch.cat(list(v.values()), dim=0)
-
-    os.makedirs(tgt, exist_ok=True)
-    for i, states in enumerate(hf_states):
-        print(f"Saving to {os.path.join(tgt, hf_state_names[i])}...", flush=True)
-        torch.save(states, os.path.join(tgt, hf_state_names[i]))
-    for filename in remain_files:
-        print(f"Copying {filename}...", flush=True)
-        os.system(f"cp {os.path.join(src, filename)} {tgt}")
+        if index_dict is not None:
+            for k in llama_states.keys():
+                index_dict["weight_map"][k] = filename
+        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
+        torch.save(llama_states, os.path.join(tgt, filename))
+        del states
+
+    print("Saving config and tokenizer...")
+    # index.json
+    if index_dict is not None:
+        with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
+            json.dump(index_dict, fp, indent=2)
+    # tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(src)
+    tokenizer.init_kwargs.pop("auto_map", None)
+    tokenizer.save_pretrained(tgt)
+    # config
+    save_conifg(config, tgt)
 
     print("Done!")

@@ -112,7 +118,6 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--src", type=str, help="Input folder")
     parser.add_argument("--tgt", type=str, help="Output folder")
-    parser.add_argument("--tp_size", type=int, help="world_size of tensor parallel")
     args = parser.parse_args()

@@ -122,4 +127,4 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-    convert(args.src, args.tgt, args.tp_size)
+    convert(args.src, args.tgt)
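With --tp_size gone, the converter takes only the two folders, e.g. python convert2llama.py --src /path/to/internlm2-hf --tgt /path/to/llama-out; the script name and paths here are assumed, only the --src/--tgt flags come from the diff. A minimal, hypothetical check that the converted folder loads with the stock Llama classes:

# Hypothetical sanity check; the path is a placeholder for the --tgt folder used above.
from transformers import LlamaForCausalLM, LlamaTokenizer

tgt = "/path/to/llama-out"                       # assumed converter output folder
model = LlamaForCausalLM.from_pretrained(tgt)    # reads the saved config and weight-map index
tokenizer = LlamaTokenizer.from_pretrained(tgt)  # tokenizer written by the converter
print(model.config.num_key_value_heads, model.config.attention_bias)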