lint check

pull/627/head
x54-729 2024-01-19 17:32:42 +08:00
parent 3d609d8e38
commit d99ba98663
1 changed file with 14 additions and 8 deletions


@@ -1,16 +1,25 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
-import os
 import json
+import os
+
 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
+from transformers import AutoConfig, LlamaConfig, LlamaTokenizer


 def save_conifg(config, tgt):
     config_dict = config.to_dict()
-    unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
+    unnecessary_keys = [
+        "_name_or_path",
+        "auto_map",
+        "transformers_version",
+        "model_type",
+        "architectures",
+        "tokenizer_class",
+        "attn_implementation",
+    ]
     for k in unnecessary_keys:
         config_dict.pop(k, None)
     config_dict["attention_bias"] = config_dict.pop("bias")
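To make the effect of this hunk concrete, here is a minimal sketch, not part of the commit, of the same pruning applied to a plain dict; the keys and values are made up, and only the pop calls and the bias rename mirror the code above.

# Minimal sketch of the config pruning in the hunk above, applied to a plain
# dict instead of a real transformers config. Keys and values are illustrative.
config_dict = {
    "_name_or_path": "some/local/path",  # hypothetical
    "auto_map": {},
    "hidden_size": 4096,
    "bias": False,
}
for k in ["_name_or_path", "auto_map", "transformers_version"]:
    config_dict.pop(k, None)  # pop with a default never raises when a key is absent
config_dict["attention_bias"] = config_dict.pop("bias")  # LLaMA-style configs call it attention_bias
print(config_dict)  # {'hidden_size': 4096, 'attention_bias': False}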
@@ -29,7 +38,6 @@ def convert(src, tgt):
     head_dim = config.hidden_size // config.num_attention_heads
     num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

     # load index json file
     index_file = os.path.join(src, "pytorch_model.bin.index.json")
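The two derived quantities in this hunk describe the grouped-query attention layout used later when unpacking the fused wqkv weights. A small worked example, with InternLM2-7B-like numbers assumed purely for illustration:

# Worked example of the two derived quantities above; the concrete numbers are
# assumed for illustration, not read from an actual checkpoint config.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8  # grouped-query attention: several Q heads share each K/V head

head_dim = hidden_size // num_attention_heads                      # 4096 // 32 = 128
num_key_value_groups = num_attention_heads // num_key_value_heads  # 32 // 8 = 4 Q heads per K/V head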
@@ -54,9 +62,7 @@ def convert(src, tgt):
                     gs=2 + num_key_value_groups,
                     d=head_dim,
                 )
-                wq, wk, wv = torch.split(
-                    v, [num_key_value_groups, 1, 1], dim=1
-                )
+                wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
                 wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
                 wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
                 wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
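To see what the rearrange/split combination in this hunk produces, here is a toy-sized sketch, separate from the script and with made-up dimensions, that unpacks a fused wqkv weight the same way:

import torch
from einops import rearrange

# Toy-sized sketch of the wqkv unpacking above; dimensions are illustrative only.
num_key_value_heads = 2      # h: number of K/V heads
num_key_value_groups = 2     # Q heads per K/V head, so 4 Q heads in total
head_dim = 3                 # d
hidden = 5                   # dim

gs = 2 + num_key_value_groups  # each K/V head carries its Q group plus one K and one V slice
wqkv = torch.randn(num_key_value_heads * gs * head_dim, hidden)

v = rearrange(wqkv, "(h gs d) dim -> h gs d dim", gs=gs, d=head_dim)
wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
wq = rearrange(wq, "h gs d dim -> (h gs d) dim")  # (4 * 3, 5): LLaMA-style q_proj layout
wk = rearrange(wk, "h gs d dim -> (h gs d) dim")  # (2 * 3, 5): k_proj
wv = rearrange(wv, "h gs d dim -> (h gs d) dim")  # (2 * 3, 5): v_proj
print(wq.shape, wk.shape, wv.shape)  # torch.Size([12, 5]) torch.Size([6, 5]) torch.Size([6, 5])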
@@ -94,7 +100,7 @@ def convert(src, tgt):
                 llama_states[k] = v

         if index_dict is not None:
-            for k in llama_states.keys():
+            for k in llama_states:
                 index_dict["weight_map"][k] = filename
         print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
         torch.save(llama_states, os.path.join(tgt, filename))
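Finally, a hedged sketch of the bookkeeping this last hunk belongs to: each converted tensor name is recorded against the shard file that will hold it, and the accumulated mapping is what a sharded checkpoint's pytorch_model.bin.index.json contains. Paths, shapes, and file names below are illustrative only.

import json
import os

import torch

# Sketch of the shard-index bookkeeping, not the conversion script itself.
tgt = "/tmp/llama-out"                                # hypothetical output directory
os.makedirs(tgt, exist_ok=True)

index_dict = {"metadata": {}, "weight_map": {}}
llama_states = {"model.layers.0.self_attn.q_proj.weight": torch.zeros(2, 2)}
filename = "pytorch_model-00001-of-00002.bin"         # hypothetical shard name

for k in llama_states:                                # same iteration style the lint fix settles on
    index_dict["weight_map"][k] = filename
torch.save(llama_states, os.path.join(tgt, filename))

with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
    json.dump(index_dict, fp, indent=2)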