Add P-Tuning v2

pull/236/merge
duzx16 2023-03-31 10:43:55 +08:00
parent 323ce7c865
commit 968a30672a
7 changed files with 982 additions and 10 deletions

View File

@ -10,20 +10,12 @@ ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进
*Read this in [English](README_en.md).*
## 更新信息
**[2023/03/31]** 增加基于 P-Tuning-v2 的微调实现,最低只需 8GB 显存即可进行模型微调。详见[模型微调](ptuning/README.md)。
**[2023/03/23]** 增加API部署感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT)。增加Embedding量化模型[ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)。增加对基于Apple Silicon的Mac上GPU加速的支持。
**[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4)
## 友情链接
以下是部分基于本仓库开发的开源项目:
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调
以下是部分针对本项目的教程/文档:
* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
如果你有其他好的项目/教程的话欢迎参照上述格式添加到README中并提出 [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork).
## 使用方式
### 硬件需求
@ -171,6 +163,9 @@ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).hal
```
即可使用在 Mac 上使用 GPU 加速模型推理。
## 模型微调
详见 [ptuning/README.md](ptuning/README.md)。
## ChatGLM-6B 示例
以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现!
@ -259,6 +254,16 @@ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).hal
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
## 友情链接
以下是部分基于本仓库开发的开源项目:
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调
以下是部分针对本项目的教程/文档:
* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
如果你有其他好的项目/教程的话欢迎参照上述格式添加到README中并提出 [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork).
## 引用
如果你觉得我们的工作有帮助的话,请考虑引用下列论文

70
ptuning/README.md Normal file
View File

@ -0,0 +1,70 @@
# ChatGLM-6B-PT
本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2将需要微调的参数量减少到原来的0.1%再通过模型量化、Gradient Checkpoint等方法最低只需要 8GB 显存即可运行。
下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。
## 软件依赖
除 ChatGLM-6B 的依赖之外,还需要按照以下依赖
```
pip install rouge_chinese nltk jieba datasets
```
## 使用方法
### 下载数据集
ADGEN 数据集任务为根据输入content生成一段广告词summary
```json
{
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
```
从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud]() 下载处理好的 ADGEN数据集将解压后的 `AdvertiseGen` 目录放到本目录下。
### 训练
运行以下指令进行训练:
```shell
bash train.sh
```
`train.sh` 中的`PRE_SEQ_LEN`和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。
### 推理
将`evaluate.sh`中的`CHECKPOINT`更改为训练时保存的checkpoint名称运行以下指令进行模型推理和评测
```shell
bash evaluate.sh
```
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`
### 例子
#### 示例1
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
* 微调前Output: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
* 微调后Output: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
#### 示例2
* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
* 微调前Output: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
* 微调后Output: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
## 使用自己的数据集
修改 `train.sh``evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的json格式数据集路径并将`prompt_column`和`response_column`改为json文件中输入文本和输出文本对应的key。
## 引用
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
pages={61--68},
year={2022}
}
```

217
ptuning/arguments.py Normal file
View File

@ -0,0 +1,217 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": (
"Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
)
},
)
quantization_bit: Optional[int] = field(
default=None
)
pre_seq_len: Optional[int] = field(
default=None
)
prefix_projection: bool = field(
default=False
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
prompt_column: Optional[str] = field(
default=None,
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
)
response_column: Optional[str] = field(
default=None,
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
)
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={
"help": (
"An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
)
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": (
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
)
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": (
"Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
)
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": (
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
)
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": (
"The token to force as the first generated token after the decoder_start_token_id."
"Useful for multilingual models like mBART where the first generated token"
"needs to be the target language token (Usually it is the target language token)"
)
},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length

20
ptuning/evaluate.sh Normal file
View File

@ -0,0 +1,20 @@
PRE_SEQ_LEN=8
CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
STEP=3000
CUDA_VISIBLE_DEVICES=0 python3 main.py \
--do_predict \
--test_file AdvertiseGen/dev.json \
--overwrite_cache \
--prompt_column content \
--response_column summary \
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
--output_dir ./output/$CHECKPOINT \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_eval_batch_size 1 \
--predict_with_generate \
--max_predict_samples 10 \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4

389
ptuning/main.py Normal file
View File

@ -0,0 +1,389 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import logging
import os
import sys
import json
import numpy as np
from datasets import load_dataset
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
set_seed,
)
from trainer_seq2seq import Seq2SeqTrainer
from arguments import ModelArguments, DataTrainingArguments
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
# datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Load dataset
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
raw_datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
config.pre_seq_len = model_args.pre_seq_len
config.prefix_projection = model_args.prefix_projection
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, revision=True, trust_remote_code=True)
model = model.half()
if model_args.quantization_bit is not None:
print(f"Quantized to {model_args.quantization_bit} bit")
model = model.quantize(model_args.quantization_bit)
model.transformer.prefix_encoder.float()
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
column_names = raw_datasets["train"].column_names
elif training_args.do_eval:
column_names = raw_datasets["validation"].column_names
elif training_args.do_predict:
column_names = raw_datasets["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return
# Get the column names for input/target.
prompt_column = data_args.prompt_column
response_column = data_args.response_column
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
def preprocess_function_eval(examples):
inputs, targets = [], []
for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]:
inputs.append(examples[prompt_column][i])
targets.append(examples[response_column][i])
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True)
labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
if data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def preprocess_function_train(examples):
max_seq_length = data_args.max_source_length + data_args.max_target_length
model_inputs = {
"input_ids": [],
"labels": [],
}
for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]:
prompt, answer = examples[prompt_column][i], examples[response_column][i]
prompt = prefix + prompt
a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(a_ids) > data_args.max_source_length - 1:
a_ids = a_ids[: data_args.max_source_length - 1]
if len(b_ids) > data_args.max_target_length - 2:
b_ids = b_ids[: data_args.max_target_length - 2]
input_ids = a_ids + [150001, 150004] + b_ids + [150005]
context_length = input_ids.index(150004)
mask_position = context_length - 1
labels = [-100] * context_length + input_ids[mask_position+1:]
pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
labels = labels + [tokenizer.pad_token_id] * pad_len
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def print_dataset_example(example):
print("input_ids",example["input_ids"])
print("inputs", tokenizer.decode(example["input_ids"]))
print("label_ids", example["labels"])
print("labels", tokenizer.decode(example["labels"]))
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function_train,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on train dataset",
)
print_dataset_example(train_dataset[0])
if training_args.do_eval:
max_target_length = data_args.val_max_target_length
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function_eval,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on validation dataset",
)
print_dataset_example(eval_dataset[0])
if training_args.do_predict:
max_target_length = data_args.val_max_target_length
if "test" not in raw_datasets:
raise ValueError("--do_predict requires a test dataset")
predict_dataset = raw_datasets["test"]
if data_args.max_predict_samples is not None:
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
predict_dataset = predict_dataset.map(
preprocess_function_eval,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on prediction dataset",
)
print_dataset_example(predict_dataset[0])
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=None,
)
# Metric
def compute_metrics(eval_preds):
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
if data_args.ignore_pad_token_for_loss:
# Replace -100 in the labels as we can't decode them.
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
score_dict = {
"rouge-1": [],
"rouge-2": [],
"rouge-l": [],
"bleu-4": []
}
for pred, label in zip(decoded_preds, decoded_labels):
hypothesis = list(jieba.cut(pred))
reference = list(jieba.cut(label))
rouge = Rouge()
scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
result = scores[0]
for k, v in result.items():
score_dict[k].append(round(v["f"] * 100, 4))
bleu_score = sentence_bleu([list(label)], list(pred))
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
for k, v in score_dict.items():
score_dict[k] = float(np.mean(v))
return score_dict
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
training_args.generation_num_beams = (
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
)
# Initialize our Trainer
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
# elif last_checkpoint is not None:
# checkpoint = last_checkpoint
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95)
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Predict ***")
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95)
metrics = predict_results.metrics
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
if trainer.is_world_process_zero():
if training_args.predict_with_generate:
predictions = tokenizer.batch_decode(
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
predictions = [pred.strip() for pred in predictions]
labels = tokenizer.batch_decode(
predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
labels = [label.strip() for label in labels]
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
with open(output_prediction_file, "w") as writer:
for p, l in zip(predictions, labels):
writer.write(json.dumps({"labels": l, "predict": p}, ensure_ascii=False))
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

26
ptuning/train.sh Normal file
View File

@ -0,0 +1,26 @@
PRE_SEQ_LEN=8
LR=1e-2
CUDA_VISIBLE_DEVICES=0 python3 main.py \
--do_train \
--train_file AdvertiseGen/train.json \
--validation_file AdvertiseGen/dev.json \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /mnt/vepfs/workspace/zxdu/chatglm_6b \
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR-dev \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4

245
ptuning/trainer_seq2seq.py Normal file
View File

@ -0,0 +1,245 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from transformers.trainer_utils import PredictionOutput
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Seq2SeqTrainer(Trainer):
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
**gen_kwargs
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init `compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (`Dataset`, *optional*):
Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is `"eval"` (default)
max_length (`int`, *optional*):
The maximum target length to use when predicting with the generate method.
num_beams (`int`, *optional*):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
gen_kwargs:
Additional `generate` specific kwargs.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
self._gen_kwargs = gen_kwargs
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict(
self,
test_dataset: Dataset,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "test",
**gen_kwargs
) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
will also return metrics, like in `evaluate()`.
Args:
test_dataset (`Dataset`):
Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. Has to implement the method `__len__`
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is `"eval"` (default)
max_length (`int`, *optional*):
The maximum target length to use when predicting with the generate method.
num_beams (`int`, *optional*):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
gen_kwargs:
Additional `generate` specific kwargs.
<Tip>
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
padding in a token classification task) the predictions will be padded (on the right) to allow for
concatenation into one array. The padding index is -100.
</Tip>
Returns: *NamedTuple* A namedtuple with the following keys:
- predictions (`np.ndarray`): The predictions on `test_dataset`.
- label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
labels).
"""
gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
self._gen_kwargs = gen_kwargs
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to evaluate.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (`bool`):
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
labels (each being optional).
"""
if not self.args.predict_with_generate or prediction_loss_only:
return super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)
# XXX: adapt synced_gpus for fairscale as well
gen_kwargs = self._gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.model.config.max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
)
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
gen_kwargs["synced_gpus"] = (
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
)
if "attention_mask" in inputs:
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
if "global_attention_mask" in inputs:
gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
# prepare generation inputs
# some encoder-decoder models can have varying encoder's and thus
# varying model input names
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
generation_inputs = inputs[self.model.encoder.main_input_name]
else:
generation_inputs = inputs[self.model.main_input_name]
gen_kwargs["input_ids"] = generation_inputs
generated_tokens = self.model.generate(**gen_kwargs)
generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
# in case the batch is shorter than max length, the output should be padded
if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
gen_kwargs["max_new_tokens"] + 1
):
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
loss = None
if self.args.prediction_loss_only:
return (loss, None, None)
if has_labels:
labels = inputs["labels"]
if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
gen_kwargs["max_new_tokens"] + 1
):
labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
else:
labels = None
return (loss, generated_tokens, labels)
def _pad_tensors_to_max_len(self, tensor, max_length):
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, : tensor.shape[-1]] = tensor
return padded_tensor