mirror of https://github.com/InternLM/InternLM
[Tool]: Fix the issue of safetensors conversion LLama error (#732)
parent
c4108d3431
commit
2db5604288
|
@ -60,7 +60,7 @@ def convert(src, tgt):
|
||||||
|
|
||||||
head_dim = config.hidden_size // config.num_attention_heads
|
head_dim = config.hidden_size // config.num_attention_heads
|
||||||
num_key_value_groups = config.num_attention_heads \
|
num_key_value_groups = config.num_attention_heads \
|
||||||
// config.num_key_value_heads
|
// config.num_key_value_heads
|
||||||
|
|
||||||
# load index json file
|
# load index json file
|
||||||
index_file = 'pytorch_model.bin.index.json'
|
index_file = 'pytorch_model.bin.index.json'
|
||||||
|
@ -138,7 +138,11 @@ def convert(src, tgt):
|
||||||
index_dict['weight_map'][k] = filename
|
index_dict['weight_map'][k] = filename
|
||||||
|
|
||||||
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
|
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
|
||||||
torch.save(llama_states, os.path.join(tgt, filename))
|
if filename.endswith('.safetensors'):
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
|
||||||
|
else:
|
||||||
|
torch.save(llama_states, os.path.join(tgt, filename))
|
||||||
del states
|
del states
|
||||||
|
|
||||||
print('Saving config and tokenizer...', flush=True)
|
print('Saving config and tokenizer...', flush=True)
|
||||||
|
|
Loading…
Reference in New Issue