fix the issue of safetensors conversion LLama error

pull/732/head
tifa 2024-04-11 14:35:01 +08:00
parent c4108d3431
commit fb5c656598
1 changed files with 6 additions and 2 deletions

View File

@ -138,6 +138,10 @@ def convert(src, tgt):
index_dict['weight_map'][k] = filename
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
if filename.endswith('.safetensors'):
from safetensors.torch import save_file
save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
else:
torch.save(llama_states, os.path.join(tgt, filename))
del states