[Tool]: Fix the issue of safetensors conversion LLama error (#732)

pull/735/head
tifa 2024-04-11 14:54:55 +08:00 committed by GitHub
parent c4108d3431
commit 2db5604288
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 2 deletions

View File

@ -138,6 +138,10 @@ def convert(src, tgt):
index_dict['weight_map'][k] = filename
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
if filename.endswith('.safetensors'):
from safetensors.torch import save_file
save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
else:
torch.save(llama_states, os.path.join(tgt, filename))
del states