mirror of https://github.com/InternLM/InternLM
Use tempfile for convert2hf.py
parent ed04c7edb0
commit efec8ab87e
@@ -4,6 +4,7 @@ import os
 import random
 import re
 import shutil
+import tempfile
 
 import torch
 from modeling_internlm import InternLMConfig, InternLMForCausalLM
@@ -15,10 +16,8 @@ NUM_SHARDS = {
 
 
 def convert2hf(model_config, states_tp_pps):
-    folder = f"/dev/shm/wait_to_upload_weight_tmp_{random.random()}/"
-    os.makedirs(folder, exist_ok=True)
-
-    try:
+    with tempfile.TemporaryDirectory() as folder:
         states = merge_pp(states_tp_pps)[0]
 
         if "embedding.word_embeddings.weight" in states:
@@ -91,9 +90,6 @@ def convert2hf(model_config, states_tp_pps):
         model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16, low_cpu_mem_usage=True)
         del model.config._name_or_path
 
-    finally:
-        shutil.rmtree(folder)
-
     return config, model
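The diff swaps a hand-rolled temporary directory under /dev/shm (created with os.makedirs and torn down in a try/finally via shutil.rmtree) for Python's tempfile.TemporaryDirectory context manager, which creates a uniquely named directory and deletes it, contents included, when the with-block exits, even if an exception is raised. Below is a minimal sketch of that pattern; save_and_reload is a hypothetical stand-in for the save_pretrained/from_pretrained round trip inside convert2hf, not code from this repository.

import os
import tempfile

def save_and_reload(payload: bytes) -> bytes:
    # TemporaryDirectory() creates a unique directory and removes it
    # (including its contents) when the with-block exits, even on
    # exceptions -- no explicit shutil.rmtree is needed.
    with tempfile.TemporaryDirectory() as folder:
        path = os.path.join(folder, "pytorch_model.bin")  # hypothetical file name
        with open(path, "wb") as f:
            f.write(payload)
        with open(path, "rb") as f:
            data = f.read()
    # At this point `folder` no longer exists on disk.
    return data

if __name__ == "__main__":
    assert save_and_reload(b"\x00" * 8) == b"\x00" * 8

One behavioural difference worth noting: the old path lived on /dev/shm, a RAM-backed tmpfs, while TemporaryDirectory defaults to the system temporary directory (typically /tmp); if the RAM-backed location matters, passing dir="/dev/shm" to TemporaryDirectory restores it.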