mirror of https://github.com/InternLM/InternLM
[Tool]: Update tools/convert2llama.py to support `safetensors` format (#730)
parent 861327b572
commit c4108d3431

--- a/tools/convert2llama.py
+++ b/tools/convert2llama.py
@@ -9,6 +9,28 @@ from tqdm import tqdm
 from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
 
 
+def weight_load(fp, **kwargs):
+    """Load weights from a file."""
+    is_safetensors = kwargs.pop('is_safetensors', False)
+
+    if is_safetensors:
+        try:
+            from safetensors import safe_open
+        except ImportError:
+            raise ImportError(
+                'Before loading ckpts in the `safetensors` format, '
+                'please install the `safetensors` package first.')
+
+        model = safe_open(fp, framework='pt')
+        state_dict = {}
+        for k in model.keys():
+            state_dict[k] = model.get_tensor(k)
+        return state_dict
+
+    else:
+        return torch.load(fp, **kwargs)
+
+
 def save_conifg(config, tgt):
     config_dict = config.to_dict()
     unnecessary_keys = [
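For context, a minimal sketch of how the new weight_load helper is meant to be driven: the caller inspects each shard's extension and passes is_safetensors accordingly, which is exactly what convert() does below. The checkpoint path is hypothetical, and weight_load is assumed to be in scope (e.g. run inside the patched tools/convert2llama.py).

    import os

    src = '/path/to/hf-checkpoint'  # hypothetical sharded checkpoint directory

    for filename in sorted(os.listdir(src)):
        # Only weight shards are of interest; skip config, tokenizer, index files.
        if not any(filename.endswith(ext) for ext in ('.bin', '.safetensors')):
            continue
        # weight_load dispatches to safetensors.safe_open or torch.load.
        states = weight_load(os.path.join(src, filename),
                             is_safetensors=filename.endswith('.safetensors'))
        print(f'{filename}: {len(states)} tensors')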
@@ -41,19 +63,29 @@ def convert(src, tgt):
         // config.num_key_value_heads
 
     # load index json file
-    index_file = os.path.join(src, 'pytorch_model.bin.index.json')
-    if os.path.exists(index_file):
-        with open(index_file) as fp:
+    index_file = 'pytorch_model.bin.index.json'
+    if os.path.exists(os.path.join(src, index_file)):
+        with open(os.path.join(src, index_file)) as fp:
             index_dict = json.load(fp)
             index_dict['weight_map'] = {}
     else:
-        index_dict = None
+        index_file = 'model.safetensors.index.json'
+        if os.path.exists(os.path.join(src, index_file)):
+            with open(os.path.join(src, index_file)) as fp:
+                index_dict = json.load(fp)
+                index_dict['weight_map'] = {}
+        else:
+            index_dict = None
 
     os.makedirs(tgt, exist_ok=True)
     for filename in tqdm(os.listdir(src)):
-        if not filename.endswith('.bin'):
+        if not any(filename.endswith(ext) for ext in ('.bin', '.safetensors')):
            continue
-        states = torch.load(os.path.join(src, filename))
+
+        print(f'Loading {os.path.join(src, filename)}...', flush=True)
+        states = weight_load(os.path.join(src, filename),
+                             is_safetensors=filename.endswith('.safetensors'))
+
         llama_states = {}
         for k, v in states.copy().items():
             if 'wqkv' in k:
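The index handling above follows the Hugging Face sharded-checkpoint convention: pytorch_model.bin.index.json for .bin shards and model.safetensors.index.json for .safetensors shards, each carrying a weight_map from parameter name to shard file. A rough sketch of what json.load yields here; the size and parameter names are illustrative only, not taken from a real model.

    index_dict = {
        'metadata': {'total_size': 15475417088},  # illustrative total size in bytes
        'weight_map': {
            # illustrative InternLM2-style keys; real names depend on the checkpoint
            'model.tok_embeddings.weight': 'model-00001-of-00002.safetensors',
            'model.layers.0.attention.wqkv.weight': 'model-00001-of-00002.safetensors',
        },
    }

    # convert() keeps the metadata but resets weight_map, then refills it with the
    # renamed (LLaMA-style) keys as each shard is converted and saved.
    index_dict['weight_map'] = {}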
@@ -104,15 +136,15 @@ def convert(src, tgt):
         if index_dict is not None:
             for k in llama_states:
                 index_dict['weight_map'][k] = filename
-        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
+
+        print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
         torch.save(llama_states, os.path.join(tgt, filename))
         del states
 
-    print('Saving config and tokenizer...')
+    print('Saving config and tokenizer...', flush=True)
     # index.json
     if index_dict is not None:
-        with open(os.path.join(tgt, 'pytorch_model.bin.index.json'),
-                  'w') as fp:
+        with open(os.path.join(tgt, index_file), 'w') as fp:
             json.dump(index_dict, fp, indent=2)
     # tokenizer
     tokenizer = LlamaTokenizer.from_pretrained(src)
@@ -120,7 +152,8 @@ def convert(src, tgt):
     tokenizer.save_pretrained(tgt)
     # config
     save_conifg(config, tgt)
-    print('Done!')
+
+    print('Done!', flush=True)
 
 
 def parse_args():
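As a quick sanity check after conversion, the output directory should load with the stock LLaMA classes from transformers. The target path below is hypothetical.

    from transformers import AutoConfig, LlamaTokenizer

    tgt = '/path/to/converted-llama'  # hypothetical output directory

    config = AutoConfig.from_pretrained(tgt)
    tokenizer = LlamaTokenizer.from_pretrained(tgt)
    print(config.model_type, config.num_hidden_layers, tokenizer.vocab_size)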