mirror of https://github.com/InternLM/InternLM
lint check
parent
3d609d8e38
commit
d99ba98663
|
@ -1,16 +1,25 @@
|
|||
# Copyright (c) InternLM. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
|
||||
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
|
||||
|
||||
|
||||
def save_conifg(config, tgt):
|
||||
config_dict = config.to_dict()
|
||||
unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
|
||||
unnecessary_keys = [
|
||||
"_name_or_path",
|
||||
"auto_map",
|
||||
"transformers_version",
|
||||
"model_type",
|
||||
"architectures",
|
||||
"tokenizer_class",
|
||||
"attn_implementation",
|
||||
]
|
||||
for k in unnecessary_keys:
|
||||
config_dict.pop(k, None)
|
||||
config_dict["attention_bias"] = config_dict.pop("bias")
|
||||
|
@ -30,7 +39,6 @@ def convert(src, tgt):
|
|||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
|
||||
|
||||
# load index json file
|
||||
index_file = os.path.join(src, "pytorch_model.bin.index.json")
|
||||
if os.path.exists(index_file):
|
||||
|
@ -54,9 +62,7 @@ def convert(src, tgt):
|
|||
gs=2 + num_key_value_groups,
|
||||
d=head_dim,
|
||||
)
|
||||
wq, wk, wv = torch.split(
|
||||
v, [num_key_value_groups, 1, 1], dim=1
|
||||
)
|
||||
wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
|
||||
wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
|
||||
wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
|
||||
wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
|
||||
|
@ -94,7 +100,7 @@ def convert(src, tgt):
|
|||
llama_states[k] = v
|
||||
|
||||
if index_dict is not None:
|
||||
for k in llama_states.keys():
|
||||
for k in llama_states:
|
||||
index_dict["weight_map"][k] = filename
|
||||
print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
|
||||
torch.save(llama_states, os.path.join(tgt, filename))
|
||||
|
|
Loading…
Reference in New Issue