update convert script

pull/536/head
x54-729 2023-12-12 12:23:57 +08:00
parent d9262da635
commit b5f2a3ead4
1 changed file with 291 additions and 154 deletions


@@ -1,192 +1,329 @@
# Copyright (c) InternLM. All rights reserved.
"""
python convert2hf.py --src /path/to/src --tgt /path/to/tgt \
--max_shard 2G --max_pos 8192 \
--tokenizer /path/to/tokenizer.model \
"""
import argparse
import math
import gc
import json
import os
import re
import tempfile
import time
import torch
from internlm_model import InternLMConfig, InternLMForCausalLM, InternLMTokenizer
from tqdm import tqdm
from transformers.modeling_utils import no_init_weights
NUM_SHARDS = {
"7B": 1,
}
embedding_key_list = ["embedding.word_embeddings.weight", "embedding.weight", "tok_embeddings.weight", None]
def _find_max_tp_pp(names):
ckpt_names = []
for name in names:
if name.startswith("model_t") and not name.endswith("md5"):
# _t: avoid conflicting with model_config.pt
ckpt_names.append(name)
max_tp, max_pp = -1, -1
for ckpt in ckpt_names:
_, tp, pp = os.path.splitext(ckpt)[0].split("_")
max_tp = max(max_tp, int(tp[2:]) + 1)
max_pp = max(max_pp, int(pp[2:]) + 1)
return max_tp, max_pp
def convert2hf(model_config, states_tp_pps):
with tempfile.TemporaryDirectory() as folder:
states = merge_pp(states_tp_pps)[0]
if "embedding.word_embeddings.weight" in states:
embedding_key = "embedding.word_embeddings.weight"
elif "embedding.weight" in states:
embedding_key = "embedding.weight"
else:
print("Check embedding states'names in below:", flush=True)
print(list(states.keys()), flush=True)
dims_per_head = model_config["hidden_size"] // model_config["num_attention_heads"]
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
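# Illustrative values (assuming a head dim of 128): this is the standard RoPE
# schedule inv_freq[i] = 1 / 10000^(2i / 128) for i = 0..63; it is written below
# as each layer's self_attn.rotary_emb.inv_freq.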
current_states = {}
current_states["model.embed_tokens.weight"] = states.pop(embedding_key)
current_states["model.norm.weight"] = states.pop("norm.weight")
current_states["lm_head.weight"] = states.pop("head.weight")
for i in range(model_config["num_layers"]):
states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None)
wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
)
bqkv = states.pop(f"blocks.{i}.mixer.Wqkv.bias").reshape(3, model_config["num_attention_heads"], -1)
current_states[f"model.layers.{i}.self_attn.q_proj.weight"] = wqkv[0].reshape(
-1, model_config["hidden_size"]
)
current_states[f"model.layers.{i}.self_attn.q_proj.bias"] = bqkv[0].reshape(-1)
current_states[f"model.layers.{i}.self_attn.k_proj.weight"] = wqkv[1].reshape(
-1, model_config["hidden_size"]
)
current_states[f"model.layers.{i}.self_attn.k_proj.bias"] = bqkv[1].reshape(-1)
current_states[f"model.layers.{i}.self_attn.v_proj.weight"] = wqkv[2].reshape(
-1, model_config["hidden_size"]
)
current_states[f"model.layers.{i}.self_attn.v_proj.bias"] = bqkv[2].reshape(-1)
current_states[f"model.layers.{i}.self_attn.o_proj.weight"] = states.pop(
f"blocks.{i}.mixer.out_proj.weight"
)
current_states[f"model.layers.{i}.self_attn.o_proj.bias"] = states.pop(f"blocks.{i}.mixer.out_proj.bias")
current_states[f"model.layers.{i}.mlp.gate_proj.weight"] = states.pop(f"blocks.{i}.mlp.w1.weight")
current_states[f"model.layers.{i}.mlp.down_proj.weight"] = states.pop(f"blocks.{i}.mlp.w3.weight")
current_states[f"model.layers.{i}.mlp.up_proj.weight"] = states.pop(f"blocks.{i}.mlp.w2.weight")
current_states[f"model.layers.{i}.input_layernorm.weight"] = states.pop(f"blocks.{i}.norm1.weight")
current_states[f"model.layers.{i}.post_attention_layernorm.weight"] = states.pop(f"blocks.{i}.norm2.weight")
current_states[f"model.layers.{i}.self_attn.rotary_emb.inv_freq"] = inv_freq
config = InternLMConfig(
hidden_size=model_config["hidden_size"],
intermediate_size=compute_intermediate_size(model_config["hidden_size"]),
num_attention_heads=model_config["num_attention_heads"],
num_hidden_layers=model_config["num_layers"],
rms_norm_eps=1e-06,
bias=True,
)
if model_config["vocab_size"] != -1:
config.vocab_size = model_config["vocab_size"]
config.save_pretrained(folder)
torch.save(current_states, os.path.join(folder, "pytorch_model.bin"))
model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16)
del model.config._name_or_path
return config, model
def compute_intermediate_size(n):
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
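# Worked example (illustrative, not part of the original script): the FFN width
# is roughly 8/3 of the hidden size, rounded up to a multiple of 256. For a
# hypothetical hidden_size of 4096: ceil(4096 * 8 / 3) = 10923, + 255 = 11178,
# and 11178 // 256 * 256 = 11008.
assert compute_intermediate_size(4096) == 11008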
def load_source(src):
"""
load model_config.pt and model_tp{x}_pp{x}.pt from ``src``
:return:
- model_config: dict
- states: 2-d array. states[i][j] stands for state_dict of tp_i pp_j
"""
# config
print("Config loading", flush=True)
config_file = os.path.join(src, "model_config.pt")
assert os.path.isfile(config_file), f"model_config.pt is not found in: {os.listdir(src)}"
model_config = torch.load(config_file)
print(model_config)
print("Config loaded.", flush=True)
# checkpoint
# find tp pp
assert os.path.isdir(src), "not a folder."
ckpt_names = os.listdir(src)
max_tp, max_pp = _find_max_tp_pp(ckpt_names)
# 2-d array tp_rank, pp_rank
print("Source Checkpoint Loading", flush=True)
states = [[None for _ in range(max_pp)] for __ in range(max_tp)]
for tp in tqdm(range(max_tp)):
for pp in tqdm(range(max_pp)):
ckpt_name = os.path.join(src, f"model_tp{tp}_pp{pp}.pt")
states[tp][pp] = torch.load(ckpt_name, map_location="cpu")
print("Source Checkpoint Loaded", flush=True)
return model_config, states
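# Expected source layout (illustrative): ``src`` holds model_config.pt plus one
# file per rank named model_tp{tp}_pp{pp}.pt, e.g. model_tp0_pp0.pt and
# model_tp0_pp1.pt for a 1-way tensor / 2-way pipeline parallel checkpoint;
# states[0][1] is then the state dict of tp rank 0, pp rank 1.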
def merge_pp(states_tp_pp):
max_tp = len(states_tp_pp)
max_pp = len(states_tp_pp[0])
full_states = []
for tp in range(max_tp):
layer_shift = 0
tp_states = {}
for pp in range(max_pp):
_layer_shift = 0
states = states_tp_pp[tp][pp]
keys = list(states.keys())
for key in keys:
match = re.search(r"\.\d+\.", key)
if match is not None:
# layers
s, e = match.span()
layer_idx = int(key[s + 1 : e - 1]) + layer_shift
_layer_shift = max(_layer_shift, int(key[s + 1 : e - 1]))
name = key[:s] + f".{layer_idx}." + key[e:]
tp_states[name] = states[key]
else:
tp_states[key] = states[key]
layer_shift += _layer_shift + 1
full_states.append({(key[6:] if key.startswith("model.") else key): value for key, value in tp_states.items()})
return full_states
def merge(states):
"""
Merge state dicts of pipeline format and shift some layers.
:return:
- states: merged state dict
"""
# merge pp
merged_states = []
print("Pipeline Merging", flush=True)
for tp_state in tqdm(states):
layer_shift = 0
shifted_state = {}
# shift key
for tp_pp_state in tp_state:
_layer_shift = 0
keys = list(tp_pp_state.keys())
for key in keys:
if key.endswith(".inv_freq"):
continue
match = re.search(r"\.\d+\.", key)
name = key
if match is not None:
# layers
s, e = match.span()
layer_idx = int(key[s + 1 : e - 1]) + layer_shift
_layer_shift = max(_layer_shift, int(key[s + 1 : e - 1]))
name = key[:s] + f".{layer_idx}." + key[e:]
if name.startswith("model."):
name = name[6:]
shifted_state[name] = tp_pp_state[key]
layer_shift += _layer_shift + 1
merged_states.append(shifted_state)
print("Pipeline Merged", flush=True)
return merged_states
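# Illustration (assumed sizes and key layout): with two pipeline stages of 16
# layers each, pp rank 0 contributes blocks.0 .. blocks.15 and pushes
# layer_shift to 16, so the key "blocks.3.mixer.Wqkv.weight" coming from
# pp rank 1 is renamed to "blocks.19.mixer.Wqkv.weight" in the merged dict.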
def convert(src, tgt, tokenizer, dtype, max_shard_size, max_pos, rope_scaling):
"""
Convert state_dict to hf format.
1. Load and merge state dict
2. Convert to huggingface
3. Load tokenizer and save it with ``tokenizer.save_pretrained``
4. Load state dict to the model
5. Call ``model.save_pretrained`` to save checkpoints.
"""
# load states
model_config, src_states = load_source(src)
states = merge(src_states)
del src_states
num_shards = len(states)
print("Converting to huggingface format...", flush=True)
n_heads = model_config["num_attention_heads"]
dim = model_config["hidden_size"]
# n_heads_per_shard = n_heads // num_shards
# dims_per_head = dim // n_heads
intermediate_size = None
print("Start converting...", flush=True)
state_dict = {}
for layer_i in tqdm(range(model_config["num_layers"])):
wqkvs = [
states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.weight").reshape(3, n_heads // num_shards, -1, dim)
for tp in range(num_shards)
]
bqkvs = [
states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.bias").reshape(3, n_heads // num_shards, -1)
for tp in range(num_shards)
]
state_dict.update(
{
f"model.layers.{layer_i}.input_layernorm.weight": states[0][f"blocks.{layer_i}.norm1.weight"].clone(),
f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
f"blocks.{layer_i}.norm2.weight"
].clone(),
}
)
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat(
[wqkvs[i][0] for i in range(num_shards)],
dim=0,
).reshape(dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat(
[bqkvs[i][0] for i in range(num_shards)],
dim=0,
).reshape(-1)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat(
[wqkvs[i][1] for i in range(num_shards)],
dim=0,
).reshape(dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat(
[bqkvs[i][1] for i in range(num_shards)],
dim=0,
).reshape(-1)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[wqkvs[i][2] for i in range(num_shards)],
dim=0,
).reshape(dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat(
[bqkvs[i][2] for i in range(num_shards)],
dim=0,
).reshape(-1)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[states[i][f"blocks.{layer_i}.mixer.out_proj.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.bias"] = states[0][f"blocks.{layer_i}.mixer.out_proj.bias"]
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[states[i][f"blocks.{layer_i}.mlp.w1.weight"] for i in range(num_shards)], dim=0
)
intermediate_size, _ = state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"].shape
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[states[i][f"blocks.{layer_i}.mlp.w3.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[states[i][f"blocks.{layer_i}.mlp.w2.weight"] for i in range(num_shards)], dim=0
)
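# Layout note (as assumed by the reshape above): each tensor-parallel shard
# packs Wqkv as (3, n_heads / num_shards, head_dim, hidden_size), with index
# 0/1/2 on the first axis holding that shard's q/k/v rows; concatenating the
# shards along the head axis and flattening restores the full
# (hidden_size, hidden_size) q/k/v projections expected by the HF layout.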
# embedding
for embedding_key in embedding_key_list:
if embedding_key in states[0]:
break
if embedding_key is None:
raise KeyError("Cannot find embedding key!")
if model_config["embed_split_hidden"]:
embed_concat_dim = 1
tok_emb_list = [states[i][embedding_key] for i in range(num_shards)]
else:
embed_concat_dim = 0
_, size_1 = states[0][embedding_key].shape
embdim_pertp = size_1 // num_shards
tok_emb_list = [
torch.concat(
[
states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)]
for tp in range(num_shards)
],
dim=0,
)
for local_rank in range(num_shards)
]
state_dict.update(
{
"model.norm.weight": states[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim),
"lm_head.weight": torch.cat([states[i]["head.weight"] for i in range(num_shards)], dim=0),
},
)
# initialize model
# tokenizer
tokenizer = InternLMTokenizer(tokenizer)
# config
config = InternLMConfig(
vocab_size=model_config["vocab_size"],
hidden_size=model_config["hidden_size"],
intermediate_size=intermediate_size,
num_attention_heads=model_config["num_attention_heads"],
num_hidden_layers=model_config["num_layers"],
rms_norm_eps=model_config["layer_norm_epsilon"],
bias=True,
rope_theta=model_config.get("rope_base", 10000),
rope_scaling=rope_scaling,
)
config.max_position_embeddings = max_pos
# set bos eos pad to avoid improper generation
# since model.generate will create attention_mask
# according to pad_token_id and bos_token_id
config.bos_token_id = tokenizer.bos_token_id
config.eos_token_id = tokenizer.eos_token_id
config.pad_token_id = tokenizer.pad_token_id
# model
print("Initializing model...", flush=True)
start = time.time()
with no_init_weights():
model = InternLMForCausalLM._from_config(config, torch_dtype=dtype)
print(f"Initializing model takes {time.time() - start}s", flush=True)
model.load_state_dict(state_dict)
del states
gc.collect()
print(f"Saving model to {tgt}...", flush=True)
tokenizer.save_pretrained(tgt)
model.save_pretrained(tgt, max_shard_size=max_shard_size)
# fix auto_map in config
with open(os.path.join(tgt, "config.json")) as fp:
config_dict = json.load(fp)
config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMForCausalLM"
with open(os.path.join(tgt, "config.json"), "w") as fp:
json.dump(config_dict, fp, indent=2)
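# Note: the patch above assumes the saved config.json already contains an
# "auto_map" block; only its "AutoModel" entry is overridden here, and a
# missing block would raise a KeyError.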
def convert_tokenizer(src, tgt):
assert os.path.isfile(src)
tokenizer = InternLMTokenizer(src)
tokenizer.save_pretrained(tgt)
def get_rope_scaling(args):
if args.rotary_type == "origin":
return None
elif args.rotary_type == "dynamic":
return {"type": args.rotary_type, "factor": args.scaling_factor}
else:
raise NotImplementedError(f"Unknown rope type {args.rotary_type}")
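# Example (hypothetical flags): --rotary_type dynamic --scaling_factor 2.0
# returns {"type": "dynamic", "factor": 2.0}, which convert() passes straight
# into InternLMConfig(rope_scaling=...); --rotary_type origin returns None.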
def print_args(args):
print("-------------- Arguments --------------")
print(f"Source Path: {args.src_folder}")
print(f"Target Path: {args.tgt_folder}")
print(f"Source Path: {args.src}")
print(f"Target Path: {args.tgt}")
print(f"Dtype: {args.dtype}")
print(f"Max Shard Size: {args.max_shard}")
print(f"Max Position Embedding: {args.max_pos}")
print(f"Tokenizer Path: {args.tokenizer}")
print(f"Rotary Type: {args.rotary_type}")
print(f"Scaling Factor: {args.scaling_factor}")
print("---------------------------------------")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--src_folder', type=str, default='~/test/')  # folder of the checkpoint to convert to hf format
parser.add_argument('--tgt_folder', type=str, default='~/output/')  # target folder for the converted checkpoint
parser.add_argument('--tokenizer', type=str, default='~/test/tokenizer.model')  # path to the tokenizer file
parser.add_argument("--dtype", type=str, default="float16")  # dtype of the converted model
parser.add_argument("--max_shard", type=str, default="10GB")  # max size of each checkpoint shard after conversion
args = parser.parse_args()
dtype = getattr(torch, args.dtype)
def load(fp):
with open(fp, "rb") as f:
pt_data = torch.load(f, map_location="cpu")
return pt_data
folder = args.src_folder
target_folder = args.tgt_folder
tokenizer = InternLMTokenizer(args.tokenizer)
tokenizer.save_pretrained(target_folder)
model_config = load(os.path.join(folder, "model_config.pt"))
fns = list(os.listdir(folder))
model_fns = []
for fn in fns:
if fn.startswith("model_t") and not fn.endswith("md5"):
model_fns.append(fn)
max_tp, max_pp = -1, -1
for fn in model_fns:
_, tp, pp = os.path.splitext(fn)[0].split("_")
max_pp = max(max_pp, int(pp[2:]) + 1)
max_tp = max(max_tp, int(tp[2:]) + 1)
states_tp_pps = [[]]
for pp in range(max_pp):
model_name = f"model_tp0_pp{pp}.pt"
states = load(os.path.join(folder, model_name))
states_tp_pps[0].append(states)
config, model = convert2hf(model_config, states_tp_pps)
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
os.makedirs(target_folder, exist_ok=True)
model.save_pretrained(target_folder, max_shard_size=args.max_shard, torch_dtype=dtype)
# TODO There should be a better way to add this.
with open(os.path.join(target_folder, "config.json")) as fp:
config_dict = json.load(fp)
config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMForCausalLM"
with open(os.path.join(target_folder, "config.json"), "w") as fp:
json.dump(config_dict, fp, indent=2)
def parse_args():
parser = argparse.ArgumentParser()
# model
parser.add_argument("--src", type=str, default=None, help="Input folder")
parser.add_argument("--tgt", type=str, help="Output folder")
parser.add_argument("--dtype", default="bfloat16", type=str, help="Data type after converting")
parser.add_argument("--max_shard", type=str, default="10GB", help="Max size of every sharded checkpoint.")
parser.add_argument("--max_pos", type=int, default=4096, help="Max position embedding of model.")
# tokenizer
parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer model.")
# rope
parser.add_argument("--rotary_type", type=str, default="origin", help="Rope type", choices=["origin", "dynamic"])
parser.add_argument("--scaling_factor", type=float, default=1.0, help="Scaling factor of dynamic rope.")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
print_args(args)
dtype = getattr(torch, args.dtype)
rope_scaling = get_rope_scaling(args)
assert args.src is not None, "--src is needed!"
assert args.tokenizer is not None, "--tokenizer is needed!"
start = time.time()
convert(args.src, args.tgt, args.tokenizer, dtype, args.max_shard, args.max_pos, rope_scaling)
print(f"Converting model takes {time.time() - start}s in total", flush=True)