mirror of https://github.com/THUDM/ChatGLM-6B
Save PrefixEncoder params only
parent 4478546058 · commit cbb9f44e30
ptuning/README.md

@@ -39,6 +39,18 @@ bash train.sh
 ```shell
 bash evaluate.sh
 ```
 
+**[Update 2023/04/10]** During P-tuning v2 training, the model now saves only the parameters of the PrefixEncoder. Inference therefore has to load both the original ChatGLM-6B model and the PrefixEncoder checkpoint, so the following arguments must be specified (`evaluate.sh` has been updated accordingly):
+
+```shell
+--model_name_or_path THUDM/chatglm-6b
+--ptuning_checkpoint $CHECKPOINT_PATH
+```
+
+Checkpoints saved in the old format (full parameters) are still supported; simply set `model_name_or_path` as before:
+
+```shell
+--model_name_or_path $CHECKPOINT_PATH
+```
 
 The evaluation metrics are Chinese Rouge score and BLEU-4. Generated results are saved to `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.
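The same PrefixEncoder loading can also be done by hand outside `evaluate.sh`. Below is a minimal Python sketch of the idea; the checkpoint path and `pre_seq_len=128` are illustrative assumptions, not values fixed by this commit:

```python
# Sketch: load ChatGLM-6B plus a PrefixEncoder-only checkpoint for inference.
# CHECKPOINT_PATH and pre_seq_len=128 are assumptions for illustration.
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

CHECKPOINT_PATH = "./output/adgen-chatglm-6b-pt-8-1e-2/checkpoint-3000"  # hypothetical

config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)

# The checkpoint contains only transformer.prefix_encoder.* weights;
# strip that prefix so the keys match the submodule's own parameter names.
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {
    k[len("transformer.prefix_encoder."):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith("transformer.prefix_encoder.")
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Half precision for the frozen backbone, fp32 for the trained prefix.
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model.eval()
```

Keeping the prefix encoder in fp32 while the frozen backbone runs in fp16 mirrors what `main.py` does in the hunks below.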
ptuning/evaluate.sh

@@ -9,7 +9,8 @@ CUDA_VISIBLE_DEVICES=0 python3 main.py \
     --overwrite_cache \
     --prompt_column content \
     --response_column summary \
-    --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
+    --model_name_or_path THUDM/chatglm-6b \
+    --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
     --output_dir ./output/$CHECKPOINT \
     --overwrite_output_dir \
     --max_source_length 64 \
ptuning/main.py

@@ -28,6 +28,7 @@ from datasets import load_dataset
 import jieba
 from rouge_chinese import Rouge
 from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+import torch
 
 import transformers
 from transformers import (
@@ -110,13 +111,28 @@ def main():
 
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
 
-    model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
+    if model_args.ptuning_checkpoint is not None:
+        # Evaluation
+        # Loading extra state dict of prefix encoder
+        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
+        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
+        new_prefix_state_dict = {}
+        for k, v in prefix_state_dict.items():
+            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+    else:
+        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
 
     if model_args.quantization_bit is not None:
         print(f"Quantized to {model_args.quantization_bit} bit")
         model = model.quantize(model_args.quantization_bit)
-    model = model.half()
-    model.transformer.prefix_encoder.float()
+    if model_args.pre_seq_len is not None:
+        # P-tuning v2
+        model = model.half()
+        model.transformer.prefix_encoder.float()
+    else:
+        # Finetune
+        model = model.float()
 
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
 
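The second part of this hunk splits precision by mode: under P-tuning v2 (`pre_seq_len` set) the frozen backbone is cast to fp16 while the trainable PrefixEncoder stays in fp32, whereas full finetuning keeps everything in fp32. A quick sanity check of that split, assuming a `model` built as in `main.py` with `pre_seq_len` set:

```python
# Sketch: confirm the dtype split after model.half() + prefix_encoder.float().
# Assumes `model` was constructed as in main.py (P-tuning v2 branch).
import torch

for name, param in model.named_parameters():
    expected = torch.float32 if "prefix_encoder" in name else torch.float16
    assert param.dtype == expected, f"unexpected dtype for {name}: {param.dtype}"
```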
File diff suppressed because it is too large.
ptuning/trainer_seq2seq.py

@@ -19,7 +19,7 @@ from torch import nn
 from torch.utils.data import Dataset
 
 from transformers.deepspeed import is_deepspeed_zero3_enabled
-from transformers.trainer import Trainer
+from trainer import Trainer
 from transformers.trainer_utils import PredictionOutput
 from transformers.utils import logging
 
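The suppressed large diff is presumably the vendored `trainer.py` that `trainer_seq2seq.py` now imports in place of the stock `transformers` Trainer; given this commit's title, its job is to checkpoint only the PrefixEncoder weights. A minimal sketch of that kind of override (class name and details are assumptions, not the actual file contents):

```python
import os

import torch
from transformers.trainer import Trainer


class PrefixEncoderTrainer(Trainer):  # hypothetical name
    """Trainer variant that checkpoints only the PrefixEncoder weights."""

    def _save(self, output_dir=None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Filter the full state dict down to transformer.prefix_encoder.*,
        # which is all that P-tuning v2 actually trains.
        full_state = self.model.state_dict()
        prefix_only = {
            k: v for k, v in full_state.items()
            if k.startswith("transformer.prefix_encoder.")
        }
        torch.save(prefix_only, os.path.join(output_dir, "pytorch_model.bin"))
```

Saving only the prefix shrinks each checkpoint from the full ~13 GB model to the few-megabyte PrefixEncoder, which is why inference now needs both `--model_name_or_path` and `--ptuning_checkpoint`.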