mirror of https://github.com/InternLM/InternLM
commit d99ba98663 ("lint check")
parent 3d609d8e38
@@ -1,16 +1,25 @@
 # Copyright (c) InternLM. All rights reserved.
 import argparse
-import os
 import json
+import os

 import torch
 from einops import rearrange
 from tqdm import tqdm
-from transformers import AutoConfig, LlamaTokenizer, LlamaConfig
+from transformers import AutoConfig, LlamaConfig, LlamaTokenizer

+
 def save_conifg(config, tgt):
     config_dict = config.to_dict()
-    unnecessary_keys = ["_name_or_path", "auto_map", "transformers_version", "model_type", "architectures", "tokenizer_class", "attn_implementation"]
+    unnecessary_keys = [
+        "_name_or_path",
+        "auto_map",
+        "transformers_version",
+        "model_type",
+        "architectures",
+        "tokenizer_class",
+        "attn_implementation",
+    ]
     for k in unnecessary_keys:
         config_dict.pop(k, None)
     config_dict["attention_bias"] = config_dict.pop("bias")
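Note (not part of the diff): save_conifg strips metadata keys that a LlamaConfig should not inherit from the source checkpoint and renames bias to attention_bias, which is the matching LlamaConfig field. A minimal sketch of how the cleaned dict might be consumed; build_llama_config, internlm_config, and the LlamaConfig(**...) call are assumptions for illustration, not code from this commit.

# Sketch only: clean the source config dict and seed a LlamaConfig with it.
from transformers import AutoConfig, LlamaConfig

def build_llama_config(src: str) -> LlamaConfig:
    # hypothetical: load the original (remote-code) config from the source checkpoint
    internlm_config = AutoConfig.from_pretrained(src, trust_remote_code=True)
    config_dict = internlm_config.to_dict()
    for k in ("_name_or_path", "auto_map", "transformers_version", "model_type",
              "architectures", "tokenizer_class", "attn_implementation"):
        config_dict.pop(k, None)  # drop metadata the LLaMA config should not carry over
    config_dict["attention_bias"] = config_dict.pop("bias")  # LLaMA's name for the QKV bias flag
    return LlamaConfig(**config_dict)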
@@ -29,7 +38,6 @@ def convert(src, tgt):

     head_dim = config.hidden_size // config.num_attention_heads
     num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

-
     # load index json file
     index_file = os.path.join(src, "pytorch_model.bin.index.json")
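Aside (assumption, not from this commit): pytorch_model.bin.index.json is the standard transformers shard index; its weight_map maps each parameter name to the shard file that stores it, which is presumably how convert discovers which .bin shards to iterate over. A hedged sketch of reading it; shards_from_index is a made-up helper name.

# Illustrative reader for the shard index the script loads.
import json
import os
from collections import defaultdict

def shards_from_index(src: str) -> dict:
    with open(os.path.join(src, "pytorch_model.bin.index.json")) as f:
        index = json.load(f)
    shard_to_params = defaultdict(list)
    for param_name, shard_file in index["weight_map"].items():
        shard_to_params[shard_file].append(param_name)
    return dict(shard_to_params)  # shard filename -> parameter names stored in it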
@@ -54,9 +62,7 @@ def convert(src, tgt):
                     gs=2 + num_key_value_groups,
                     d=head_dim,
                 )
-                wq, wk, wv = torch.split(
-                    v, [num_key_value_groups, 1, 1], dim=1
-                )
+                wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
                 wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
                 wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
                 wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
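Context for the hunk above (illustration only): the fused wqkv tensor groups the projections per key/value head, so each group carries num_key_value_groups query slices plus one key and one value slice, which is why gs = 2 + num_key_value_groups. A self-contained sketch of the same rearrange/split with made-up sizes:

# Toy reproduction of the split above; all sizes are invented for the example.
import torch
from einops import rearrange

num_attention_heads, num_key_value_heads, head_dim, hidden_size = 8, 2, 16, 128
num_key_value_groups = num_attention_heads // num_key_value_heads  # query heads per kv head
gs = 2 + num_key_value_groups  # per group: (gs - 2) query slices + 1 key slice + 1 value slice

# fused projection as stored in the checkpoint: (h * gs * d, dim)
wqkv = torch.randn(num_key_value_heads * gs * head_dim, hidden_size)

v = rearrange(wqkv, "(h gs d) dim -> h gs d dim", gs=gs, d=head_dim)
wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
wv = rearrange(wv, "h gs d dim -> (h gs d) dim")

assert wq.shape == (num_attention_heads * head_dim, hidden_size)
assert wk.shape == (num_key_value_heads * head_dim, hidden_size)
assert wv.shape == (num_key_value_heads * head_dim, hidden_size)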
@@ -94,7 +100,7 @@ def convert(src, tgt):
             llama_states[k] = v

         if index_dict is not None:
-            for k in llama_states.keys():
+            for k in llama_states:
                 index_dict["weight_map"][k] = filename
         print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
         torch.save(llama_states, os.path.join(tgt, filename))
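Aside (assumption, not shown in this diff): after every shard has been rewritten, the accumulated index_dict would typically be written next to the converted shards so transformers can locate each tensor when reloading. A hypothetical closing step:

# Hypothetical final step: persist the rebuilt weight map alongside the converted shards.
import json
import os

def write_index(tgt: str, index_dict: dict) -> None:
    with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index_dict, f, indent=2)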