mirror of https://github.com/InternLM/InternLM
[Tool]: Update tools/convert2llama.py to support `safetensors` format (#730)
parent 861327b572
commit c4108d3431
@@ -9,6 +9,28 @@ from tqdm import tqdm
 from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
 
 
+def weight_load(fp, **kwargs):
+    """Load weights from a file."""
+    is_safetensors = kwargs.pop('is_safetensors', False)
+
+    if is_safetensors:
+        try:
+            from safetensors import safe_open
+        except ImportError:
+            raise ImportError(
+                'Before loading ckpts in the `safetensors` format, '
+                'please install the `safetensors` package first.')
+
+        model = safe_open(fp, framework='pt')
+        state_dict = {}
+        for k in model.keys():
+            state_dict[k] = model.get_tensor(k)
+        return state_dict
+
+    else:
+        return torch.load(fp, **kwargs)
+
+
 def save_conifg(config, tgt):
     config_dict = config.to_dict()
     unnecessary_keys = [
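For reference, a minimal usage sketch of the new weight_load helper defined above. The shard filenames are placeholders, and extra keyword arguments such as map_location are only forwarded to torch.load on the .bin path:

# Hypothetical shard names; is_safetensors switches between the two loaders.
bin_states = weight_load('pytorch_model-00001-of-00002.bin', map_location='cpu')
st_states = weight_load('model-00001-of-00002.safetensors', is_safetensors=True)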
@@ -41,9 +63,15 @@ def convert(src, tgt):
                 // config.num_key_value_heads
 
     # load index json file
-    index_file = os.path.join(src, 'pytorch_model.bin.index.json')
-    if os.path.exists(index_file):
-        with open(index_file) as fp:
+    index_file = 'pytorch_model.bin.index.json'
+    if os.path.exists(os.path.join(src, index_file)):
+        with open(os.path.join(src, index_file)) as fp:
             index_dict = json.load(fp)
             index_dict['weight_map'] = {}
     else:
+        index_file = 'model.safetensors.index.json'
+        if os.path.exists(os.path.join(src, index_file)):
+            with open(os.path.join(src, index_file)) as fp:
+                index_dict = json.load(fp)
+                index_dict['weight_map'] = {}
+        else:
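The branch above can be read as: prefer pytorch_model.bin.index.json, fall back to model.safetensors.index.json, otherwise fall through to the existing else branch (its body is outside this hunk). A condensed standalone restatement for illustration only; load_index is not part of the commit, and returning None for the no-index case is an assumption:

import json
import os

def load_index(src):
    """Prefer the .bin index, fall back to the safetensors index."""
    for index_file in ('pytorch_model.bin.index.json',
                       'model.safetensors.index.json'):
        path = os.path.join(src, index_file)
        if os.path.exists(path):
            with open(path) as fp:
                index_dict = json.load(fp)
            index_dict['weight_map'] = {}
            return index_file, index_dict
    return None, None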
@@ -51,9 +79,13 @@ def convert(src, tgt):
 
     os.makedirs(tgt, exist_ok=True)
     for filename in tqdm(os.listdir(src)):
-        if not filename.endswith('.bin'):
+        if not any(filename.endswith(ext) for ext in ('.bin', '.safetensors')):
             continue
-        states = torch.load(os.path.join(src, filename))
+
+        print(f'Loading {os.path.join(src, filename)}...', flush=True)
+        states = weight_load(os.path.join(src, filename),
+                             is_safetensors=filename.endswith('.safetensors'))
+
         llama_states = {}
         for k, v in states.copy().items():
             if 'wqkv' in k:
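The loop now accepts both .bin and .safetensors shards and lets weight_load pick the loader from the file extension. A small illustration with hypothetical directory entries:

# Illustration only: which entries the updated loop would convert.
filenames = ['pytorch_model-00001-of-00002.bin',   # hypothetical shard names
             'model-00001-of-00002.safetensors',
             'config.json',
             'tokenizer.model']
for filename in filenames:
    if not any(filename.endswith(ext) for ext in ('.bin', '.safetensors')):
        continue  # config.json and tokenizer.model are skipped
    print(filename, 'is_safetensors =', filename.endswith('.safetensors'))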
@@ -104,15 +136,15 @@ def convert(src, tgt):
         if index_dict is not None:
             for k in llama_states:
                 index_dict['weight_map'][k] = filename
-        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
+
+        print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
         torch.save(llama_states, os.path.join(tgt, filename))
         del states
 
-    print('Saving config and tokenizer...')
+    print('Saving config and tokenizer...', flush=True)
     # index.json
     if index_dict is not None:
-        with open(os.path.join(tgt, 'pytorch_model.bin.index.json'),
-                  'w') as fp:
+        with open(os.path.join(tgt, index_file), 'w') as fp:
             json.dump(index_dict, fp, indent=2)
     # tokenizer
     tokenizer = LlamaTokenizer.from_pretrained(src)
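Because index_file now carries whichever name was detected in src, the rewritten index lands in tgt under that same name (pytorch_model.bin.index.json or model.safetensors.index.json) instead of being hard-coded. A hypothetical sketch of the JSON that json.dump writes; the entries are placeholders, and the real weight_map is filled from llama_states during the loop above:

# Shape of the emitted index file (illustrative values only).
import json

index_dict = {
    'metadata': {'total_size': 0},  # carried over from the source index
    'weight_map': {
        'model.layers.0.self_attn.q_proj.weight':
            'pytorch_model-00001-of-00002.bin',
    },
}
print(json.dumps(index_dict, indent=2))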
@@ -120,7 +152,8 @@ def convert(src, tgt):
     tokenizer.save_pretrained(tgt)
     # config
     save_conifg(config, tgt)
-    print('Done!')
+
+    print('Done!', flush=True)
 
 
 def parse_args():
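End-to-end usage for reference, assuming the script is imported as a module from the tools directory; the paths are placeholders and the actual CLI flags are whatever parse_args() defines. Note that the safetensors package must be installed when the source shards are *.safetensors, otherwise the new weight_load raises an ImportError:

# Programmatic sketch; convert(src, tgt) is defined in tools/convert2llama.py.
from convert2llama import convert

convert('/path/to/internlm2_hf_ckpt',  # src: InternLM2 HF checkpoint directory
        '/path/to/llama_ckpt')         # tgt: output directory in Llama format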