mirror of https://github.com/InternLM/InternLM
fix small bugs of convert2hf;remove low_cpu_usage
parent
987602d4df
commit
2f56b6a00b
|
@ -4,10 +4,10 @@
|
||||||
|
|
||||||
## 权重转换
|
## 权重转换
|
||||||
|
|
||||||
`convert2hf.py` 可以将训练保存的权重一键转换为 transformers 格式。
|
`convert2hf.py` 可以将训练保存的权重一键转换为 transformers 格式。在根目录下执行:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
|
python tools/transformers/convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
|
||||||
```
|
```
|
||||||
|
|
||||||
然后可以使用 `from_pretrained` 接口加载:
|
然后可以使用 `from_pretrained` 接口加载:
|
||||||
|
|
|
@ -85,7 +85,7 @@ def convert2hf(model_config, states_tp_pps):
|
||||||
config.save_pretrained(folder)
|
config.save_pretrained(folder)
|
||||||
torch.save(current_states, os.path.join(folder, "pytorch_model.bin"))
|
torch.save(current_states, os.path.join(folder, "pytorch_model.bin"))
|
||||||
|
|
||||||
model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16)
|
||||||
del model.config._name_or_path
|
del model.config._name_or_path
|
||||||
|
|
||||||
return config, model
|
return config, model
|
||||||
|
|
|
@ -31,7 +31,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.generation.streamers import BaseStreamer
|
from transformers.generation.streamers import BaseStreamer
|
||||||
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from .configuration_internlm import InternLMConfig
|
from configuration_internlm import InternLMConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
Loading…
Reference in New Issue