Use tempfile for convert2hf.py (#23)

Fix https://github.com/InternLM/InternLM/issues/50
x54-729 2023-07-17 21:08:10 +08:00 committed by GitHub
parent 59f4727675
commit 0c1060435d
6 changed files with 24 additions and 25 deletions

View File

@@ -145,10 +145,10 @@ streamlit run web_demo.py
### Convert to Transformers Format
-The model trained with InternLM can easily be converted to the HuggingFace Transformers format, which is convenient for seamless integration with various open-source projects in the community. With the help of `tools/convert2hf.py`, the weights saved during training can be converted to the transformers format with a single command.
+The model trained with InternLM can easily be converted to the HuggingFace Transformers format, which is convenient for seamless integration with various open-source projects in the community. With the help of `tools/transformers/convert2hf.py`, the weights saved during training can be converted to the transformers format with a single command.
```bash
-python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
+python tools/transformers/convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer ./tools/V7_sft.model
```
After conversion, it can be loaded as transformers with the following code

View File

@@ -152,10 +152,10 @@ Please refer to [Usage Tutorial](./doc/en/usage.md) to start InternLM installati
### Convert to Transformers Format
-The model trained by InternLM can be easily converted to the HuggingFace Transformers format, which is convenient for seamless integration with various open-source projects in the community. With the help of `tools/convert2hf.py`, the weights saved during training can be converted into the transformers format with one command.
+The model trained by InternLM can be easily converted to the HuggingFace Transformers format, which is convenient for seamless integration with various open-source projects in the community. With the help of `tools/transformers/convert2hf.py`, the weights saved during training can be converted into the transformers format with one command.
```bash
-python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
+python tools/transformers/convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer ./tools/V7_sft.model
```
After conversion, it can be loaded as transformers with the following code
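The loading snippet referenced here lies outside this hunk. For orientation only, a minimal sketch of loading a converted checkpoint and running one generation; it assumes the converted folder is `hf_ckpt/`, that the saved config's `auto_map` exposes a causal-LM entry (this commit only patches the `AutoModel` entry), and that the tokenizer files resolve through `AutoTokenizer` — none of which is shown directly in this diff:

```python
# Hypothetical usage sketch, not part of this commit or of the README hunk above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code is required because the checkpoint ships its own modeling code.
tokenizer = AutoTokenizer.from_pretrained("hf_ckpt/", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "hf_ckpt/", torch_dtype=torch.float16, trust_remote_code=True
).cuda()

inputs = tokenizer("Hello, InternLM!", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```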

View File

@@ -8,18 +8,17 @@
## Weight Conversion
-`convert2hf.py` can convert the weights saved during training to the transformers format with a single command.
+`convert2hf.py` can convert the weights saved during training to the transformers format with a single command. Run the following command in the root directory of the repository:
```bash
-python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer ../v7_sft.model
+python tools/transformers/convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer ./tools/V7_sft.model
```
Then, you can load it using the `from_pretrained` interface:
```python
-from modeling_internlm import InternLMForCausalLM
-model = InternForCausalLM.from_pretrained("hf_ckpt/")
+>>> from transformers import AutoTokenizer, AutoModel
+>>> model = AutoModel.from_pretrained("hf_ckpt/", trust_remote_code=True).cuda()
```

View File

@@ -7,18 +7,17 @@ This folder contains the `InternLM` model in transformers format.
## Weight Conversion
-`convert2hf.py` can convert saved training weights into the transformers format with a single command.
+`convert2hf.py` can convert saved training weights into the transformers format with a single command. Execute the command in the root directory of the repository:
```bash
-python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer ../v7_sft.model
+python tools/transformers/convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer ./tools/V7_sft.model
```
Then, you can load it using the `from_pretrained` interface:
```python
-from modeling_internlm import InternLMForCausalLM
-model = InternForCausalLM.from_pretrained("hf_ckpt/")
+>>> from transformers import AutoTokenizer, AutoModel
+>>> model = AutoModel.from_pretrained("hf_ckpt/", trust_remote_code=True).cuda()
```
`intern_moss_example.py` demonstrates an example of how to use LoRA for fine-tuning on the `fnlp/moss-moon-002-sft` dataset.
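For orientation, a minimal sketch of the kind of LoRA setup such a fine-tuning example typically uses, assuming the `peft` library; the target module names and hyperparameters are illustrative and are not taken from `intern_moss_example.py`:

```python
# Hypothetical LoRA sketch (not taken from intern_moss_example.py).
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf_ckpt/", trust_remote_code=True)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                  # rank of the low-rank update matrices
    lora_alpha=32,                        # scaling applied to the update
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed attention projection names
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights remain trainable
```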

View File

@@ -1,9 +1,9 @@
import argparse
import math
+import json
import os
-import random
import re
-import shutil
+import tempfile
import torch
from modeling_internlm import InternLMConfig, InternLMForCausalLM
@@ -15,10 +15,8 @@ NUM_SHARDS = {
def convert2hf(model_config, states_tp_pps):
-    folder = f"/dev/shm/wait_to_upload_weight_tmp_{random.random()}/"
-    os.makedirs(folder, exist_ok=True)
-    try:
+    with tempfile.TemporaryDirectory() as folder:
        states = merge_pp(states_tp_pps)[0]
        if "embedding.word_embeddings.weight" in states:
@@ -88,12 +86,9 @@ def convert2hf(model_config, states_tp_pps):
        config.save_pretrained(folder)
        torch.save(current_states, os.path.join(folder, "pytorch_model.bin"))
-        model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+        model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16)
        del model.config._name_or_path
-    finally:
-        shutil.rmtree(folder)
    return config, model
@@ -169,6 +164,12 @@ if __name__ == "__main__":
    os.makedirs(target_folder, exist_ok=True)
    model.save_pretrained(target_folder, max_shard_size="20GB")
+    # TODO There should be a better way to add this.
+    with open(os.path.join(target_folder, "config.json")) as fp:
+        config_dict = json.load(fp)
+    config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMModel"
+    with open(os.path.join(target_folder, "config.json"), "w") as fp:
+        json.dump(config_dict, fp, indent=2)
    tokenizer = InternLMTokenizer(args.tokenizer)
    tokenizer.save_pretrained(target_folder)
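The core of this commit is visible in the hunks above: the hand-rolled temporary directory under `/dev/shm` is replaced by `tempfile.TemporaryDirectory`, which picks a system temp location and guarantees cleanup even if an exception is raised. A condensed before/after sketch of that pattern, simplified from the diff (the `torch.save` call stands in for the real weight-writing code):

```python
import os
import random
import shutil
import tempfile

import torch

# Before: create a uniquely named folder by hand and remember to delete it.
folder = f"/dev/shm/wait_to_upload_weight_tmp_{random.random()}/"
os.makedirs(folder, exist_ok=True)
try:
    torch.save({}, os.path.join(folder, "pytorch_model.bin"))
finally:
    shutil.rmtree(folder)  # cleanup is the caller's responsibility

# After: the context manager creates the directory and removes it automatically,
# including when the body raises.
with tempfile.TemporaryDirectory() as folder:
    torch.save({}, os.path.join(folder, "pytorch_model.bin"))
```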

View File

@@ -31,7 +31,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from .configuration_internlm import InternLMConfig
+from configuration_internlm import InternLMConfig
logger = logging.get_logger(__name__)
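The final hunk drops the relative import in `modeling_internlm.py` so that `convert2hf.py` can import it as a plain top-level module: a relative import only resolves when the file is loaded as part of a package, otherwise Python raises ImportError. As an illustration only (not part of this commit), a fallback pattern that works in both contexts would look like:

```python
# Illustrative only, not in this commit: support both import contexts.
try:
    # Works when the file is loaded as part of a package
    # (e.g. pulled in as remote code via trust_remote_code).
    from .configuration_internlm import InternLMConfig
except ImportError:
    # Works when the file sits next to convert2hf.py and is
    # imported as a top-level module.
    from configuration_internlm import InternLMConfig
```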