From b5f2a3ead4350ebb981bc61917a70122022b2151 Mon Sep 17 00:00:00 2001
From: x54-729
Date: Tue, 12 Dec 2023 12:23:57 +0800
Subject: [PATCH] update convert script

---
 tools/transformers/convert2hf.py | 445 ++++++++++++++++++++-----------
 1 file changed, 291 insertions(+), 154 deletions(-)

diff --git a/tools/transformers/convert2hf.py b/tools/transformers/convert2hf.py
index 167e02b..b1d88e0 100644
--- a/tools/transformers/convert2hf.py
+++ b/tools/transformers/convert2hf.py
@@ -1,192 +1,329 @@
+# Copyright (c) InternLM. All rights reserved.
+"""
+python convert2hf.py --src /path/to/src --tgt /path/to/tgt \
+    --max_shard 2G --max_pos 8192 \
+    --tokenizer /path/to/tokenizer.model
+"""
 import argparse
-import math
+import gc
 import json
 import os
 import re
-import tempfile
+import time
 
 import torch
-from internlm_model import InternLMConfig, InternLMForCausalLM
-from internlm_model import InternLMTokenizer
+from internlm_model import InternLMConfig, InternLMForCausalLM, InternLMTokenizer
+from tqdm import tqdm
+from transformers.modeling_utils import no_init_weights
 
-NUM_SHARDS = {
-    "7B": 1,
-}
+embedding_key_list = ["embedding.word_embeddings.weight", "embedding.weight", "tok_embeddings.weight", None]
 
 
-def convert2hf(model_config, states_tp_pps):
+def _find_max_tp_pp(names):
+    ckpt_names = []
+    for name in names:
+        if name.startswith("model_t") and not name.endswith("md5"):
+            # _t: avoid conflicting with model_config.pt
+            ckpt_names.append(name)
 
-    with tempfile.TemporaryDirectory() as folder:
-        states = merge_pp(states_tp_pps)[0]
+    max_tp, max_pp = -1, -1
+    for ckpt in ckpt_names:
+        _, tp, pp = os.path.splitext(ckpt)[0].split("_")
+        max_tp = max(max_tp, int(tp[2:]) + 1)
+        max_pp = max(max_pp, int(pp[2:]) + 1)
 
-        if "embedding.word_embeddings.weight" in states:
-            embedding_key = "embedding.word_embeddings.weight"
-        elif "embedding.weight" in states:
-            embedding_key = "embedding.weight"
-        else:
-            print("Check embedding states'names in below:", flush=True)
-            print(list(states.keys()), flush=True)
-
-        dims_per_head = model_config["hidden_size"] // model_config["num_attention_heads"]
-        base = 10000.0
-        inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-
-        current_states = {}
-
-        current_states["model.embed_tokens.weight"] = states.pop(embedding_key)
-        current_states["model.norm.weight"] = states.pop("norm.weight")
-        current_states["lm_head.weight"] = states.pop("head.weight")
-
-        for i in range(model_config["num_layers"]):
-            states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None)
-
-            wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
-                3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
-            )
-            bqkv = states.pop(f"blocks.{i}.mixer.Wqkv.bias").reshape(3, model_config["num_attention_heads"], -1)
-
-            current_states[f"model.layers.{i}.self_attn.q_proj.weight"] = wqkv[0].reshape(
-                -1, model_config["hidden_size"]
-            )
-            current_states[f"model.layers.{i}.self_attn.q_proj.bias"] = bqkv[0].reshape(-1)
-            current_states[f"model.layers.{i}.self_attn.k_proj.weight"] = wqkv[1].reshape(
-                -1, model_config["hidden_size"]
-            )
-            current_states[f"model.layers.{i}.self_attn.k_proj.bias"] = bqkv[1].reshape(-1)
-            current_states[f"model.layers.{i}.self_attn.v_proj.weight"] = wqkv[2].reshape(
-                -1, model_config["hidden_size"]
-            )
-            current_states[f"model.layers.{i}.self_attn.v_proj.bias"] = bqkv[2].reshape(-1)
-
-            current_states[f"model.layers.{i}.self_attn.o_proj.weight"] = states.pop(
-                f"blocks.{i}.mixer.out_proj.weight"
-            )
-
-            current_states[f"model.layers.{i}.self_attn.o_proj.bias"] = states.pop(f"blocks.{i}.mixer.out_proj.bias")
-
-            current_states[f"model.layers.{i}.mlp.gate_proj.weight"] = states.pop(f"blocks.{i}.mlp.w1.weight")
-            current_states[f"model.layers.{i}.mlp.down_proj.weight"] = states.pop(f"blocks.{i}.mlp.w3.weight")
-            current_states[f"model.layers.{i}.mlp.up_proj.weight"] = states.pop(f"blocks.{i}.mlp.w2.weight")
-
-            current_states[f"model.layers.{i}.input_layernorm.weight"] = states.pop(f"blocks.{i}.norm1.weight")
-            current_states[f"model.layers.{i}.post_attention_layernorm.weight"] = states.pop(f"blocks.{i}.norm2.weight")
-            current_states[f"model.layers.{i}.self_attn.rotary_emb.inv_freq"] = inv_freq
-
-        config = InternLMConfig(
-            hidden_size=model_config["hidden_size"],
-            intermediate_size=compute_intermediate_size(model_config["hidden_size"]),
-            num_attention_heads=model_config["num_attention_heads"],
-            num_hidden_layers=model_config["num_layers"],
-            rms_norm_eps=1e-06,
-            bias=True,
-        )
-
-        if model_config["vocab_size"] != -1:
-            config.vocab_size = model_config["vocab_size"]
-
-        config.save_pretrained(folder)
-        torch.save(current_states, os.path.join(folder, "pytorch_model.bin"))
-
-        model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16)
-        del model.config._name_or_path
-
-        return config, model
+    return max_tp, max_pp
 
 
-def compute_intermediate_size(n):
-    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
+def load_source(src):
+    """
+    load model_config.pt and model_tp{x}_pp{x}.pt from ``src``
+
+    :return:
+        - model_config: dict
+        - states: 2-d array. states[i][j] stands for state_dict of tp_i pp_j
+    """
+
+    # config
+    print("Config loading", flush=True)
+    config_file = os.path.join(src, "model_config.pt")
+    assert os.path.isfile(config_file), f"model_config.pt is not found in: {os.listdir(src)}"
+    model_config = torch.load(config_file)
+    print(model_config)
+    print("Config loaded.", flush=True)
+
+    # checkpoint
+    # find tp pp
+    assert os.path.isdir(src), "not a folder."
+    ckpt_names = os.listdir(src)
+    max_tp, max_pp = _find_max_tp_pp(ckpt_names)
+
+    # 2-d array tp_rank, pp_rank
+    print("Source Checkpoint Loading", flush=True)
+    states = [[None for _ in range(max_pp)] for __ in range(max_tp)]
+    for tp in tqdm(range(max_tp)):
+        for pp in tqdm(range(max_pp)):
+            ckpt_name = os.path.join(src, f"model_tp{tp}_pp{pp}.pt")
+            states[tp][pp] = torch.load(ckpt_name, map_location="cpu")
+    print("Source Checkpoint Loaded", flush=True)
+    return model_config, states
 
 
-def merge_pp(states_tp_pp):
-    max_tp = len(states_tp_pp)
-    max_pp = len(states_tp_pp[0])
+def merge(states):
+    """
+    Merge state dicts of pipeline format and shift some layers.
 
-    full_states = []
-    for tp in range(max_tp):
+    :return:
+        - states: merged state dicts, one per tp rank, with the pipeline
+          stages concatenated and the layer indices shifted accordingly
+    """
+    # merge pp
+    merged_states = []
+    print("Pipeline Merging", flush=True)
+    for tp_state in tqdm(states):
         layer_shift = 0
-
-        tp_states = {}
-        for pp in range(max_pp):
+        shifted_state = {}
+        # shift key
+        for tp_pp_state in tp_state:
             _layer_shift = 0
-            states = states_tp_pp[tp][pp]
-            keys = list(states.keys())
+            keys = list(tp_pp_state.keys())
             for key in keys:
-                match = re.search("\.\d+\.", key)
+                if key.endswith(".inv_freq"):
+                    continue
+                match = re.search(r"\.\d+\.", key)
+                name = key
                 if match is not None:
+                    # layers
                     s, e = match.span()
                     layer_idx = int(key[s + 1 : e - 1]) + layer_shift
                     _layer_shift = max(_layer_shift, int(key[s + 1 : e - 1]))
                     name = key[:s] + f".{layer_idx}." + key[e:]
-                    tp_states[name] = states[key]
-                else:
-                    tp_states[key] = states[key]
+                if name.startswith("model."):
+                    name = name[6:]
+                shifted_state[name] = tp_pp_state[key]
             layer_shift += _layer_shift + 1
-        full_states.append({(key[6:] if key.startswith("model.") else key): value for key, value in tp_states.items()})
-    return full_states
+
+        merged_states.append(shifted_state)
+
+    print("Pipeline Merged", flush=True)
+
+    return merged_states
+
+
+def convert(src, tgt, tokenizer, dtype, max_shard_size, max_pos, rope_scaling):
+    """
+    Convert state_dict to hf format.
+
+    1. Load and merge state dict
+    2. Convert to huggingface
+    3. Load tokenizer and save it with ``tokenizer.save_pretrained``
+    4. Load state dict to the model
+    5. Call ``model.save_pretrained`` to save checkpoints.
+    """
+    # load states
+    model_config, src_states = load_source(src)
+    states = merge(src_states)
+    del src_states
+
+    num_shards = len(states)
+    print("Converting to huggingface format...", flush=True)
+
+    n_heads = model_config["num_attention_heads"]
+    dim = model_config["hidden_size"]
+    # n_heads_per_shard = n_heads // num_shards
+    # dims_per_head = dim // n_heads
+    intermediate_size = None
+
+    print("Start converting...", flush=True)
+    state_dict = {}
+    for layer_i in tqdm(range(model_config["num_layers"])):
+        wqkvs = [
+            states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.weight").reshape(3, n_heads // num_shards, -1, dim)
+            for tp in range(num_shards)
+        ]
+        bqkvs = [
+            states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.bias").reshape(3, n_heads // num_shards, -1)
+            for tp in range(num_shards)
+        ]
+        state_dict.update(
+            {
+                f"model.layers.{layer_i}.input_layernorm.weight": states[0][f"blocks.{layer_i}.norm1.weight"].clone(),
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
+                    f"blocks.{layer_i}.norm2.weight"
+                ].clone(),
+            }
+        )
+        state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat(
+            [wqkvs[i][0] for i in range(num_shards)],
+            dim=0,
+        ).reshape(dim, dim)
+        state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat(
+            [bqkvs[i][0] for i in range(num_shards)],
+            dim=0,
+        ).reshape(-1)
+        state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat(
+            [wqkvs[i][1] for i in range(num_shards)],
+            dim=0,
+        ).reshape(dim, dim)
+        state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat(
+            [bqkvs[i][1] for i in range(num_shards)],
+            dim=0,
+        ).reshape(-1)
+        state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+            [wqkvs[i][2] for i in range(num_shards)],
+            dim=0,
+        ).reshape(dim, dim)
+        state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat(
+            [bqkvs[i][2] for i in range(num_shards)],
+            dim=0,
+        ).reshape(-1)
+
+        state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+            [states[i][f"blocks.{layer_i}.mixer.out_proj.weight"] for i in range(num_shards)], dim=1
+        )
+        state_dict[f"model.layers.{layer_i}.self_attn.o_proj.bias"] = states[0][f"blocks.{layer_i}.mixer.out_proj.bias"]
+        state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+            [states[i][f"blocks.{layer_i}.mlp.w1.weight"] for i in range(num_shards)], dim=0
+        )
+        intermediate_size, _ = state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"].shape
+        state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+            [states[i][f"blocks.{layer_i}.mlp.w3.weight"] for i in range(num_shards)], dim=1
+        )
+        state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+            [states[i][f"blocks.{layer_i}.mlp.w2.weight"] for i in range(num_shards)], dim=0
+        )
+
+    # embedding
+    for embedding_key in embedding_key_list:
+        if embedding_key in states[0]:
+            break
+    if embedding_key is None:
+        raise KeyError("Cannot find embedding key!")
+    if model_config["embed_split_hidden"]:
+        embed_concat_dim = 1
+        tok_emb_list = [states[i][embedding_key] for i in range(num_shards)]
+    else:
+        embed_concat_dim = 0
+        _, size_1 = states[0][embedding_key].shape
+        embdim_pertp = size_1 // num_shards
+        tok_emb_list = [
+            torch.concat(
+                [
+                    states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)]
+                    for tp in range(num_shards)
+                ],
+                dim=0,
+            )
+            for local_rank in range(num_shards)
+        ]
+    state_dict.update(
+        {
+            "model.norm.weight": states[0]["norm.weight"],
+            "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim),
+            "lm_head.weight": torch.cat([states[i]["head.weight"] for i in range(num_shards)], dim=0),
+        },
+    )
+
+    # initialize model
+    # tokenizer
+    tokenizer = InternLMTokenizer(tokenizer)
+    # config
+    config = InternLMConfig(
+        vocab_size=model_config["vocab_size"],
+        hidden_size=model_config["hidden_size"],
+        intermediate_size=intermediate_size,
+        num_attention_heads=model_config["num_attention_heads"],
+        num_hidden_layers=model_config["num_layers"],
+        rms_norm_eps=model_config["layer_norm_epsilon"],
+        bias=True,
+        rope_theta=model_config.get("rope_base", 10000),
+        rope_scaling=rope_scaling,
+    )
+    # max position embeddings
+    config.max_position_embeddings = max_pos
+    # set bos eos pad to avoid improper generation
+    # since model.generate will create attention_mask
+    # according to pad_token_id and bos_token_id
+    config.bos_token_id = tokenizer.bos_token_id
+    config.eos_token_id = tokenizer.eos_token_id
+    config.pad_token_id = tokenizer.pad_token_id
+
+    # model
+    print("Initializing model...", flush=True)
+    start = time.time()
+    with no_init_weights():
+        model = InternLMForCausalLM._from_config(config, torch_dtype=dtype)
+    print(f"Initializing model takes {time.time() - start}s", flush=True)
+    model.load_state_dict(state_dict)
+
+    del states
+    gc.collect()
+    print(f"Saving model to {tgt}...", flush=True)
+    tokenizer.save_pretrained(tgt)
+    model.save_pretrained(tgt, max_shard_size=max_shard_size)
+
+    # fix auto_map in config
+    with open(os.path.join(tgt, "config.json")) as fp:
+        config_dict = json.load(fp)
+    config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMForCausalLM"
+    with open(os.path.join(tgt, "config.json"), "w") as fp:
+        json.dump(config_dict, fp, indent=2)
+
+
+def convert_tokenizer(src, tgt):
+    assert os.path.isfile(src)
+    tokenizer = InternLMTokenizer(src)
+    tokenizer.save_pretrained(tgt)
+
+
+def get_rope_scaling(args):
+    if args.rotary_type == "origin":
+        return None
+    elif args.rotary_type == "dynamic":
+        return {"type": args.rotary_type, "factor": args.scaling_factor}
+    else:
+        raise NotImplementedError(f"Unknown rope type {args.rotary_type}")
 
 
 def print_args(args):
     print("-------------- Arguments --------------")
-    print(f"Source Path: {args.src_folder}")
-    print(f"Target Path: {args.tgt_folder}")
+    print(f"Source Path: {args.src}")
+    print(f"Target Path: {args.tgt}")
     print(f"Dtype: {args.dtype}")
     print(f"Max Shard Size: {args.max_shard}")
+    print(f"Max Position Embedding: {args.max_pos}")
     print(f"Tokenizer Path: {args.tokenizer}")
+    print(f"Rotary Type: {args.rotary_type}")
+    print(f"Scaling Factor: {args.scaling_factor}")
     print("---------------------------------------")
 
-if __name__ == "__main__":
+
+def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--src_folder', type=str, default='~/test/')  # 需要转换为hf格式的checkpoint文件夹
-    parser.add_argument('--tgt_folder', type=str, default='~/output/')  # 存放转换后checkpoint的目标文件夹
-    parser.add_argument('--tokenizer', type=str, default='~/test/tokenizer.model')  # Tokenizer 文件的路径
-    parser.add_argument("--dtype", type=str, default="float16")  # 转换后模型的 dtype
-    parser.add_argument("--max_shard", type=str, default="10GB")  # 转换后模型每个切片的大小
+    # model
+    parser.add_argument("--src", type=str, default=None, help="Input folder")
+    parser.add_argument("--tgt", type=str, help="Output folder")
+    parser.add_argument("--dtype", default="bfloat16", type=str, help="Data type after converting")
+    parser.add_argument("--max_shard", type=str, default="10GB", help="Max size of every sharded checkpoint.")
+    parser.add_argument("--max_pos", type=int, default=4096, help="Max position embedding of model.")
+    # tokenizer
+    parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer model.")
+    # rope
+    parser.add_argument("--rotary_type", type=str, default="origin", help="Rope type", choices=["origin", "dynamic"])
+    parser.add_argument("--scaling_factor", type=float, default=1.0, help="Scaling factor of dynamic rope.")
     args = parser.parse_args()
-    dtype = getattr(torch, args.dtype)
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
     print_args(args)
+    dtype = getattr(torch, args.dtype)
+    rope_scaling = get_rope_scaling(args)
 
-    def load(fp):
-        with open(fp, "rb") as f:
-            pt_data = torch.load(f, map_location="cpu")
-        return pt_data
-
-    folder = args.src_folder
-    target_folder = args.tgt_folder
-
-    tokenizer = InternLMTokenizer(args.tokenizer)
-    tokenizer.save_pretrained(target_folder)
-
-    model_config = load(os.path.join(folder, "model_config.pt"))
-
-    fns = list(os.listdir(folder))
-
-    model_fns = []
-    for fn in fns:
-        if fn.startswith("model_t") and not fn.endswith("md5"):
-            model_fns.append(fn)
-
-    max_tp, max_pp = -1, -1
-    for fn in model_fns:
-        _, tp, pp = os.path.splitext(fn)[0].split("_")
-        max_pp = max(max_pp, int(pp[2:]) + 1)
-        max_tp = max(max_tp, int(tp[2:]) + 1)
-
-    states_tp_pps = [[]]
-
-    for pp in range(max_pp):
-        model_name = f"model_tp0_pp{pp}.pt"
-        states = load(os.path.join(folder, model_name))
-        states_tp_pps[0].append(states)
-
-    config, model = convert2hf(model_config, states_tp_pps)
-    model.config.bos_token_id = tokenizer.bos_token_id
-    model.config.eos_token_id = tokenizer.eos_token_id
-    model.config.pad_token_id = tokenizer.pad_token_id
-
-    os.makedirs(target_folder, exist_ok=True)
-    model.save_pretrained(target_folder, max_shard_size=args.max_shard, torch_dtype=dtype)
-    # TODO There should be a better way to add this.
-    with open(os.path.join(target_folder, "config.json")) as fp:
-        config_dict = json.load(fp)
-    config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMForCausalLM"
-    with open(os.path.join(target_folder, "config.json"), "w") as fp:
-        json.dump(config_dict, fp, indent=2)
+    assert args.src is not None, "--src is needed!"
+    assert args.tokenizer is not None, "--tokenizer is needed!"
+    start = time.time()
+    convert(args.src, args.tgt, args.tokenizer, dtype, args.max_shard, args.max_pos, rope_scaling)
+    print(f"Converting model takes {time.time() - start}s in total", flush=True)
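
A quick sanity check of the converted output, as a minimal sketch rather than part of the patch: it assumes the folder passed as --tgt also contains the modeling/tokenization .py files referenced by the saved config's auto_map (including an AutoModelForCausalLM entry, as in the released InternLM HF checkpoints), and "/path/to/tgt" below is a placeholder for that folder.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tgt = "/path/to/tgt"  # placeholder: the folder given to --tgt above
    tokenizer = AutoTokenizer.from_pretrained(tgt, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(tgt, torch_dtype=torch.bfloat16, trust_remote_code=True)
    # generate a few tokens to confirm weights, config and tokenizer load together
    inputs = tokenizer("Hello", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=8)
    print(tokenizer.decode(output[0]))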