mirror of https://github.com/THUDM/ChatGLM-6B
Save PrefixEncoder params only
parent
4478546058
commit
cbb9f44e30
|
@ -39,6 +39,18 @@ bash train.sh
|
||||||
```shell
|
```shell
|
||||||
bash evaluate.sh
|
bash evaluate.sh
|
||||||
```
|
```
|
||||||
|
**[2023/04/10更新]** 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,在推理时需要同时载入原 ChatGLM-6B 模型以及 PrefixEncoder 的 Checkpoint,因此需要指定参数(已更新 `evaluate.sh`) :
|
||||||
|
|
||||||
|
```shell
|
||||||
|
--model_name_or_path THUDM/chatglm-6b
|
||||||
|
--ptuning_checkpoint $CHECKPOINT_PATH
|
||||||
|
```
|
||||||
|
|
||||||
|
仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
--model_name_or_path $CHECKPOINT_PATH
|
||||||
|
```
|
||||||
|
|
||||||
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
|
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
|
||||||
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。
|
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。
|
||||||
|
|
|
@ -9,7 +9,8 @@ CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--prompt_column content \
|
--prompt_column content \
|
||||||
--response_column summary \
|
--response_column summary \
|
||||||
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
|
--model_name_or_path THUDM/chatglm-6b \
|
||||||
|
--ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
|
||||||
--output_dir ./output/$CHECKPOINT \
|
--output_dir ./output/$CHECKPOINT \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--max_source_length 64 \
|
--max_source_length 64 \
|
||||||
|
|
|
@ -28,6 +28,7 @@ from datasets import load_dataset
|
||||||
import jieba
|
import jieba
|
||||||
from rouge_chinese import Rouge
|
from rouge_chinese import Rouge
|
||||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||||
|
import torch
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
@ -110,13 +111,28 @@ def main():
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
|
if model_args.ptuning_checkpoint is not None:
|
||||||
|
# Evaluation
|
||||||
|
# Loading extra state dict of prefix encoder
|
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
|
||||||
|
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
|
||||||
|
new_prefix_state_dict = {}
|
||||||
|
for k, v in prefix_state_dict.items():
|
||||||
|
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||||
|
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||||
|
else:
|
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
|
||||||
|
|
||||||
if model_args.quantization_bit is not None:
|
if model_args.quantization_bit is not None:
|
||||||
print(f"Quantized to {model_args.quantization_bit} bit")
|
print(f"Quantized to {model_args.quantization_bit} bit")
|
||||||
model = model.quantize(model_args.quantization_bit)
|
model = model.quantize(model_args.quantization_bit)
|
||||||
model = model.half()
|
if model_args.pre_seq_len is not None:
|
||||||
model.transformer.prefix_encoder.float()
|
# P-tuning v2
|
||||||
|
model = model.half()
|
||||||
|
model.transformer.prefix_encoder.float()
|
||||||
|
else:
|
||||||
|
# Finetune
|
||||||
|
model = model.float()
|
||||||
|
|
||||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -19,7 +19,7 @@ from torch import nn
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer import Trainer
|
from trainer import Trainer
|
||||||
from transformers.trainer_utils import PredictionOutput
|
from transformers.trainer_utils import PredictionOutput
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue