Browse Source

Save PrefixEncoder params only

pull/518/head
rainatam 2 years ago
parent
commit
cbb9f44e30
  1. 12
      ptuning/README.md
  2. 3
      ptuning/evaluate.sh
  3. 22
      ptuning/main.py
  4. 3824
      ptuning/trainer.py
  5. 2
      ptuning/trainer_seq2seq.py

12
ptuning/README.md

@ -39,6 +39,18 @@ bash train.sh
```shell ```shell
bash evaluate.sh bash evaluate.sh
``` ```
**[2023/04/10更新]** 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,在推理时需要同时载入原 ChatGLM-6B 模型以及 PrefixEncoder 的 Checkpoint,因此需要指定参数(已更新 `evaluate.sh`) :
```shell
--model_name_or_path THUDM/chatglm-6b
--ptuning_checkpoint $CHECKPOINT_PATH
```
仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`
```shell
--model_name_or_path $CHECKPOINT_PATH
```
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt` `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`

3
ptuning/evaluate.sh

@ -9,7 +9,8 @@ CUDA_VISIBLE_DEVICES=0 python3 main.py \
--overwrite_cache \ --overwrite_cache \
--prompt_column content \ --prompt_column content \
--response_column summary \ --response_column summary \
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \ --model_name_or_path THUDM/chatglm-6b \
--ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
--output_dir ./output/$CHECKPOINT \ --output_dir ./output/$CHECKPOINT \
--overwrite_output_dir \ --overwrite_output_dir \
--max_source_length 64 \ --max_source_length 64 \

22
ptuning/main.py

@ -28,6 +28,7 @@ from datasets import load_dataset
import jieba import jieba
from rouge_chinese import Rouge from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch
import transformers import transformers
from transformers import ( from transformers import (
@ -110,13 +111,28 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) if model_args.ptuning_checkpoint is not None:
# Evaluation
# Loading extra state dict of prefix encoder
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
print(f"Quantized to {model_args.quantization_bit} bit") print(f"Quantized to {model_args.quantization_bit} bit")
model = model.quantize(model_args.quantization_bit) model = model.quantize(model_args.quantization_bit)
model = model.half() if model_args.pre_seq_len is not None:
model.transformer.prefix_encoder.float() # P-tuning v2
model = model.half()
model.transformer.prefix_encoder.float()
else:
# Finetune
model = model.float()
prefix = data_args.source_prefix if data_args.source_prefix is not None else "" prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

3824
ptuning/trainer.py

File diff suppressed because it is too large Load Diff

2
ptuning/trainer_seq2seq.py

@ -19,7 +19,7 @@ from torch import nn
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from trainer import Trainer
from transformers.trainer_utils import PredictionOutput from transformers.trainer_utils import PredictionOutput
from transformers.utils import logging from transformers.utils import logging

Loading…
Cancel
Save