[Tool]: Update tools/convert2llama.py to support `safetensors` format (#730)

pull/732/head
Yang Gao 2024-04-10 17:06:18 +08:00 committed by GitHub
parent 861327b572
commit c4108d3431
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 44 additions and 11 deletions

View File

@ -9,6 +9,28 @@ from tqdm import tqdm
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
def weight_load(fp, **kwargs):
    """Load weights from a checkpoint file.

    Args:
        fp: Path of the checkpoint file to read.
        **kwargs: May contain ``is_safetensors`` (default ``False``), which
            selects the ``safetensors`` reader. It is removed before the
            remaining keyword arguments are forwarded to ``torch.load``;
            those remaining arguments are not used for ``safetensors``
            files.

    Returns:
        For ``safetensors`` files, a dict mapping every tensor name in the
        file to its tensor. Otherwise, whatever ``torch.load`` returns.

    Raises:
        ImportError: If ``is_safetensors`` is true but the ``safetensors``
            package cannot be imported.
    """
    is_safetensors = kwargs.pop('is_safetensors', False)
    if not is_safetensors:
        return torch.load(fp, **kwargs)

    # Imported lazily so that `.bin` checkpoints can still be converted
    # when the optional `safetensors` package is not installed.
    try:
        from safetensors import safe_open
    except ImportError as err:
        raise ImportError(
            'Before loading ckpts in the `safetensors` format, '
            'please install the `safetensors` package first.') from err

    # Open through the context manager so the underlying file handle is
    # released as soon as every tensor has been read out.
    with safe_open(fp, framework='pt') as model:
        return {k: model.get_tensor(k) for k in model.keys()}
def save_conifg(config, tgt): def save_conifg(config, tgt):
config_dict = config.to_dict() config_dict = config.to_dict()
unnecessary_keys = [ unnecessary_keys = [
@ -41,19 +63,29 @@ def convert(src, tgt):
// config.num_key_value_heads // config.num_key_value_heads
# load index json file # load index json file
index_file = os.path.join(src, 'pytorch_model.bin.index.json') index_file = 'pytorch_model.bin.index.json'
if os.path.exists(index_file): if os.path.exists(os.path.join(src, index_file)):
with open(index_file) as fp: with open(os.path.join(src, index_file)) as fp:
index_dict = json.load(fp) index_dict = json.load(fp)
index_dict['weight_map'] = {} index_dict['weight_map'] = {}
else: else:
index_dict = None index_file = 'model.safetensors.index.json'
if os.path.exists(os.path.join(src, index_file)):
with open(os.path.join(src, index_file)) as fp:
index_dict = json.load(fp)
index_dict['weight_map'] = {}
else:
index_dict = None
os.makedirs(tgt, exist_ok=True) os.makedirs(tgt, exist_ok=True)
for filename in tqdm(os.listdir(src)): for filename in tqdm(os.listdir(src)):
if not filename.endswith('.bin'): if not any(filename.endswith(ext) for ext in ('.bin', '.safetensors')):
continue continue
states = torch.load(os.path.join(src, filename))
print(f'Loading {os.path.join(src, filename)}...', flush=True)
states = weight_load(os.path.join(src, filename),
is_safetensors=filename.endswith('.safetensors'))
llama_states = {} llama_states = {}
for k, v in states.copy().items(): for k, v in states.copy().items():
if 'wqkv' in k: if 'wqkv' in k:
@ -104,15 +136,15 @@ def convert(src, tgt):
if index_dict is not None: if index_dict is not None:
for k in llama_states: for k in llama_states:
index_dict['weight_map'][k] = filename index_dict['weight_map'][k] = filename
print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
torch.save(llama_states, os.path.join(tgt, filename)) torch.save(llama_states, os.path.join(tgt, filename))
del states del states
print('Saving config and tokenizer...') print('Saving config and tokenizer...', flush=True)
# index.json # index.json
if index_dict is not None: if index_dict is not None:
with open(os.path.join(tgt, 'pytorch_model.bin.index.json'), with open(os.path.join(tgt, index_file), 'w') as fp:
'w') as fp:
json.dump(index_dict, fp, indent=2) json.dump(index_dict, fp, indent=2)
# tokenizer # tokenizer
tokenizer = LlamaTokenizer.from_pretrained(src) tokenizer = LlamaTokenizer.from_pretrained(src)
@ -120,7 +152,8 @@ def convert(src, tgt):
tokenizer.save_pretrained(tgt) tokenizer.save_pretrained(tgt)
# config # config
save_conifg(config, tgt) save_conifg(config, tgt)
print('Done!')
print('Done!', flush=True)
def parse_args(): def parse_args():