mirror of https://github.com/InternLM/InternLM
Use tempfile for convert2hf.py
parent
ed04c7edb0
commit
efec8ab87e
|
@ -4,6 +4,7 @@ import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modeling_internlm import InternLMConfig, InternLMForCausalLM
|
from modeling_internlm import InternLMConfig, InternLMForCausalLM
|
||||||
|
@ -15,10 +16,8 @@ NUM_SHARDS = {
|
||||||
|
|
||||||
|
|
||||||
def convert2hf(model_config, states_tp_pps):
|
def convert2hf(model_config, states_tp_pps):
|
||||||
folder = f"/dev/shm/wait_to_upload_weight_tmp_{random.random()}/"
|
|
||||||
os.makedirs(folder, exist_ok=True)
|
|
||||||
|
|
||||||
try:
|
with tempfile.TemporaryDirectory() as folder:
|
||||||
states = merge_pp(states_tp_pps)[0]
|
states = merge_pp(states_tp_pps)[0]
|
||||||
|
|
||||||
if "embedding.word_embeddings.weight" in states:
|
if "embedding.word_embeddings.weight" in states:
|
||||||
|
@ -91,9 +90,6 @@ def convert2hf(model_config, states_tp_pps):
|
||||||
model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||||
del model.config._name_or_path
|
del model.config._name_or_path
|
||||||
|
|
||||||
finally:
|
|
||||||
shutil.rmtree(folder)
|
|
||||||
|
|
||||||
return config, model
|
return config, model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue