[Tool]: Fix the issue of safetensors conversion LLama error (#732)

pull/735/head
tifa 2024-04-11 14:54:55 +08:00 committed by GitHub
parent c4108d3431
commit 2db5604288
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 2 deletions

View File

@ -60,7 +60,7 @@ def convert(src, tgt):
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
num_key_value_groups = config.num_attention_heads \ num_key_value_groups = config.num_attention_heads \
// config.num_key_value_heads // config.num_key_value_heads
# load index json file # load index json file
index_file = 'pytorch_model.bin.index.json' index_file = 'pytorch_model.bin.index.json'
@ -138,7 +138,11 @@ def convert(src, tgt):
index_dict['weight_map'][k] = filename index_dict['weight_map'][k] = filename
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True) print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
torch.save(llama_states, os.path.join(tgt, filename)) if filename.endswith('.safetensors'):
from safetensors.torch import save_file
save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
else:
torch.save(llama_states, os.path.join(tgt, filename))
del states del states
print('Saving config and tokenizer...', flush=True) print('Saving config and tokenizer...', flush=True)