From fb5c656598ff35c58ab362b52f2fbcfb9e290356 Mon Sep 17 00:00:00 2001
From: tifa
Date: Thu, 11 Apr 2024 14:35:01 +0800
Subject: [PATCH] fix the issue of safetensors conversion LLama error

---
 tools/convert2llama.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tools/convert2llama.py b/tools/convert2llama.py
index 1462924..de65b58 100644
--- a/tools/convert2llama.py
+++ b/tools/convert2llama.py
@@ -60,7 +60,7 @@ def convert(src, tgt):
     head_dim = config.hidden_size // config.num_attention_heads
     num_key_value_groups = config.num_attention_heads \
-        // config.num_key_value_heads
+        // config.num_key_value_heads
 
     # load index json file
     index_file = 'pytorch_model.bin.index.json'
@@ -138,7 +138,11 @@ def convert(src, tgt):
             index_dict['weight_map'][k] = filename
         print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
-        torch.save(llama_states, os.path.join(tgt, filename))
+        if filename.endswith('.safetensors'):
+            from safetensors.torch import save_file
+            save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
+        else:
+            torch.save(llama_states, os.path.join(tgt, filename))
         del states
 
     print('Saving config and tokenizer...', flush=True)
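
For reference, below is a minimal standalone sketch (not part of the patch) of the safetensors save path the new branch relies on. The target directory, shard name, and tensors here are hypothetical placeholders; the actual values in convert2llama.py come from the checkpoint index file.

# Minimal sketch, assuming torch and safetensors are installed; names are
# hypothetical and only illustrate the save_file call added by the patch.
import os

import torch
from safetensors.torch import save_file

tgt = '/tmp/llama_converted'                    # hypothetical output directory
filename = 'model-00001-of-00001.safetensors'   # hypothetical shard name
llama_states = {                                # hypothetical converted weights
    'model.embed_tokens.weight': torch.randn(8, 4),
    'lm_head.weight': torch.randn(8, 4),
}

os.makedirs(tgt, exist_ok=True)
# safetensors refuses non-contiguous or storage-sharing tensors, so make
# each tensor contiguous before writing the shard.
llama_states = {k: v.contiguous() for k, v in llama_states.items()}
save_file(llama_states, os.path.join(tgt, filename), metadata={'format': 'pt'})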