From 2f56b6a00bf7cbc318b5faf7e9799ae6871ea114 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Mon, 10 Jul 2023 17:06:09 +0800
Subject: [PATCH] fix small bugs of convert2hf;remove low_cpu_usage

---
 tools/transformers/README.md            | 4 ++--
 tools/transformers/convert2hf.py        | 2 +-
 tools/transformers/modeling_internlm.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tools/transformers/README.md b/tools/transformers/README.md
index bd16da1..879b5f5 100644
--- a/tools/transformers/README.md
+++ b/tools/transformers/README.md
@@ -4,10 +4,10 @@
 ## Weight Conversion
 
-`convert2hf.py` can convert the weights saved during training to the transformers format in one step.
+`convert2hf.py` can convert the weights saved during training to the transformers format in one step. Run the following from the repository root:
 
 ```bash
-python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
+python tools/transformers/convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
 ```
 
 The converted weights can then be loaded with the `from_pretrained` interface:
 
diff --git a/tools/transformers/convert2hf.py b/tools/transformers/convert2hf.py
index 3b38609..7e86f6f 100644
--- a/tools/transformers/convert2hf.py
+++ b/tools/transformers/convert2hf.py
@@ -85,7 +85,7 @@ def convert2hf(model_config, states_tp_pps):
         config.save_pretrained(folder)
         torch.save(current_states, os.path.join(folder, "pytorch_model.bin"))
 
-        model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+        model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16)
         del model.config._name_or_path
 
     return config, model
diff --git a/tools/transformers/modeling_internlm.py b/tools/transformers/modeling_internlm.py
index 21eb7f5..df1e19f 100644
--- a/tools/transformers/modeling_internlm.py
+++ b/tools/transformers/modeling_internlm.py
@@ -31,7 +31,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.streamers import BaseStreamer
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from .configuration_internlm import InternLMConfig
+from configuration_internlm import InternLMConfig
 
 logger = logging.get_logger(__name__)
 
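
Below is a minimal loading sketch, not part of the patch itself, illustrating the `from_pretrained` usage the README and `convert2hf.py` changes refer to. It assumes the conversion command has already written the converted checkpoint to `hf_ckpt/` (the `--tgt_folder` from the README example) and that the script runs from `tools/transformers/`, so `modeling_internlm.py` and `configuration_internlm.py` resolve as top-level modules, matching the absolute import introduced above; the dummy token ids are purely illustrative.

```python
# Minimal sketch (assumptions: run from tools/transformers/, with hf_ckpt/ as the
# --tgt_folder produced by convert2hf.py; not taken verbatim from the repository).
import torch

from modeling_internlm import InternLMForCausalLM

# Same call that convert2hf.py uses internally, pointed at the converted folder.
model = InternLMForCausalLM.from_pretrained("hf_ckpt/", torch_dtype=torch.float16)
model = model.float().eval()  # cast to fp32 so the CPU-only sanity check below works

# Dummy forward pass to sanity-check the converted weights; the token ids are
# placeholders, not output of the real tokenizer.
with torch.no_grad():
    input_ids = torch.tensor([[1, 100, 200, 300]])
    logits = model(input_ids=input_ids).logits

print(logits.shape)  # expected: (1, 4, vocab_size)
```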