lint check

pull/627/head
x54-729 2024-01-19 17:32:42 +08:00
parent 3d609d8e38
commit d99ba98663
1 changed file with 14 additions and 8 deletions

@@ -1,16 +1,25 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
-import os
 import json
+import os
+
 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
+from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
 
 
 def save_conifg(config, tgt):
     config_dict = config.to_dict()
-    unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
+    unnecessary_keys = [
+        "_name_or_path",
+        "auto_map",
+        "transformers_version",
+        "model_type",
+        "architectures",
+        "tokenizer_class",
+        "attn_implementation",
+    ]
     for k in unnecessary_keys:
         config_dict.pop(k, None)
     config_dict["attention_bias"] = config_dict.pop("bias")
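Note for readers skimming the diff: this hunk only reflows the key list; behavior is unchanged. A quick illustration of the key-stripping and bias rename, with hypothetical config values that are not from the diff:

# Illustration only (hypothetical values): the reflowed list above still strips
# HF bookkeeping keys and renames "bias" to the LLaMA-style "attention_bias".
config_dict = {"model_type": "internlm2", "bias": False, "hidden_size": 4096}
for k in ["_name_or_path", "auto_map", "transformers_version", "model_type",
          "architectures", "tokenizer_class", "attn_implementation"]:
    config_dict.pop(k, None)
config_dict["attention_bias"] = config_dict.pop("bias")
assert config_dict == {"hidden_size": 4096, "attention_bias": False}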
@@ -29,7 +38,6 @@ def convert(src, tgt):
     head_dim = config.hidden_size // config.num_attention_heads
     num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
-
     # load index json file
     index_file = os.path.join(src, "pytorch_model.bin.index.json")
@@ -54,9 +62,7 @@ def convert(src, tgt):
                 gs=2 + num_key_value_groups,
                 d=head_dim,
             )
-            wq, wk, wv = torch.split(
-                v, [num_key_value_groups, 1, 1], dim=1
-            )
+            wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
             wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
             wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
             wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
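The only visible change here is collapsing torch.split onto one line; the surrounding rearrange/split logic is what actually unpacks the fused QKV weight. A self-contained sketch of that grouped-QKV unpacking, under the assumption that the fused weight is laid out as "(h gs d) dim" per key/value head (the input pattern of the first rearrange is not shown in this hunk), with hypothetical sizes:

# Sketch only; sizes are hypothetical, layout "(h gs d) dim" is an assumption.
import torch
from einops import rearrange

hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads                       # 128
num_key_value_groups = num_attention_heads // num_key_value_heads   # 4

# Assumed fused layout: for each kv head, `num_key_value_groups` query slices
# followed by one key slice and one value slice, stacked along the output dim.
wqkv = torch.randn(num_key_value_heads * (num_key_value_groups + 2) * head_dim, hidden_size)

v = rearrange(
    wqkv,
    "(h gs d) dim -> h gs d dim",
    gs=2 + num_key_value_groups,
    d=head_dim,
)
wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
wq = rearrange(wq, "h gs d dim -> (h gs d) dim")  # (4096, 4096): all query heads
wk = rearrange(wk, "h gs d dim -> (h gs d) dim")  # (1024, 4096): kv heads only
wv = rearrange(wv, "h gs d dim -> (h gs d) dim")  # (1024, 4096): kv heads only

assert wq.shape == (num_attention_heads * head_dim, hidden_size)
assert wk.shape == (num_key_value_heads * head_dim, hidden_size)
assert wv.shape == (num_key_value_heads * head_dim, hidden_size)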
@@ -94,7 +100,7 @@ def convert(src, tgt):
             llama_states[k] = v
 
         if index_dict is not None:
-            for k in llama_states.keys():
+            for k in llama_states:
                 index_dict["weight_map"][k] = filename
         print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
         torch.save(llama_states, os.path.join(tgt, filename))
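The lint fix above just drops the redundant .keys() call; the pattern itself is standard sharded-checkpoint bookkeeping. A minimal sketch of that pattern, assuming the collected index is eventually dumped to "pytorch_model.bin.index.json" (that final write is not part of the visible hunks):

# Sketch of the shard bookkeeping shown above; the save_index step is an assumption.
import json
import os

import torch


def save_shard(llama_states, index_dict, filename, tgt):
    # Record which shard file holds each tensor, then write the shard itself.
    if index_dict is not None:
        for k in llama_states:
            index_dict["weight_map"][k] = filename
    print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
    torch.save(llama_states, os.path.join(tgt, filename))


def save_index(index_dict, tgt):
    # Write the accumulated weight_map so sharded loading can find every tensor.
    with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index_dict, f, indent=2)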