mirror of https://github.com/InternLM/InternLM
[Tool]: Fix the issue of safetensors conversion LLama error (#732)
parent
c4108d3431
commit
2db5604288
|
@ -60,7 +60,7 @@ def convert(src, tgt):
|
|||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
num_key_value_groups = config.num_attention_heads \
|
||||
// config.num_key_value_heads
|
||||
// config.num_key_value_heads
|
||||
|
||||
# load index json file
|
||||
index_file = 'pytorch_model.bin.index.json'
|
||||
|
@ -138,7 +138,11 @@ def convert(src, tgt):
|
|||
index_dict['weight_map'][k] = filename
|
||||
|
||||
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
|
||||
torch.save(llama_states, os.path.join(tgt, filename))
|
||||
if filename.endswith('.safetensors'):
|
||||
from safetensors.torch import save_file
|
||||
save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(llama_states, os.path.join(tgt, filename))
|
||||
del states
|
||||
|
||||
print('Saving config and tokenizer...', flush=True)
|
||||
|
|
Loading…
Reference in New Issue