mirror of https://github.com/THUDM/ChatGLM-6B
duzx16
2 years ago
24 changed files with 5678 additions and 95 deletions
@ -0,0 +1,133 @@ |
|||||||
|
# Byte-compiled / optimized / DLL files |
||||||
|
__pycache__/ |
||||||
|
*.py[cod] |
||||||
|
*$py.class |
||||||
|
|
||||||
|
# C extensions |
||||||
|
*.so |
||||||
|
|
||||||
|
# Distribution / packaging |
||||||
|
.Python |
||||||
|
build/ |
||||||
|
develop-eggs/ |
||||||
|
dist/ |
||||||
|
downloads/ |
||||||
|
eggs/ |
||||||
|
.eggs/ |
||||||
|
lib/ |
||||||
|
lib64/ |
||||||
|
parts/ |
||||||
|
sdist/ |
||||||
|
var/ |
||||||
|
wheels/ |
||||||
|
pip-wheel-metadata/ |
||||||
|
share/python-wheels/ |
||||||
|
*.egg-info/ |
||||||
|
.installed.cfg |
||||||
|
*.egg |
||||||
|
MANIFEST |
||||||
|
history/ |
||||||
|
|
||||||
|
# PyInstaller |
||||||
|
# Usually these files are written by a python script from a template |
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it. |
||||||
|
*.manifest |
||||||
|
*.spec |
||||||
|
|
||||||
|
# Installer logs |
||||||
|
pip-log.txt |
||||||
|
pip-delete-this-directory.txt |
||||||
|
|
||||||
|
# Unit test / coverage reports |
||||||
|
htmlcov/ |
||||||
|
.tox/ |
||||||
|
.nox/ |
||||||
|
.coverage |
||||||
|
.coverage.* |
||||||
|
.cache |
||||||
|
nosetests.xml |
||||||
|
coverage.xml |
||||||
|
*.cover |
||||||
|
*.py,cover |
||||||
|
.hypothesis/ |
||||||
|
.pytest_cache/ |
||||||
|
|
||||||
|
# Translations |
||||||
|
*.mo |
||||||
|
*.pot |
||||||
|
|
||||||
|
# Django stuff: |
||||||
|
*.log |
||||||
|
local_settings.py |
||||||
|
db.sqlite3 |
||||||
|
db.sqlite3-journal |
||||||
|
|
||||||
|
# Flask stuff: |
||||||
|
instance/ |
||||||
|
.webassets-cache |
||||||
|
|
||||||
|
# Scrapy stuff: |
||||||
|
.scrapy |
||||||
|
|
||||||
|
# Sphinx documentation |
||||||
|
docs/_build/ |
||||||
|
|
||||||
|
# PyBuilder |
||||||
|
target/ |
||||||
|
|
||||||
|
# Jupyter Notebook |
||||||
|
.ipynb_checkpoints |
||||||
|
|
||||||
|
# IPython |
||||||
|
profile_default/ |
||||||
|
ipython_config.py |
||||||
|
|
||||||
|
# pyenv |
||||||
|
.python-version |
||||||
|
|
||||||
|
# pipenv |
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. |
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies |
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not |
||||||
|
# install all needed dependencies. |
||||||
|
#Pipfile.lock |
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow |
||||||
|
__pypackages__/ |
||||||
|
|
||||||
|
# Celery stuff |
||||||
|
celerybeat-schedule |
||||||
|
celerybeat.pid |
||||||
|
|
||||||
|
# SageMath parsed files |
||||||
|
*.sage.py |
||||||
|
|
||||||
|
# Environments |
||||||
|
.env |
||||||
|
.venv |
||||||
|
env/ |
||||||
|
venv/ |
||||||
|
ENV/ |
||||||
|
env.bak/ |
||||||
|
venv.bak/ |
||||||
|
|
||||||
|
# Spyder project settings |
||||||
|
.spyderproject |
||||||
|
.spyproject |
||||||
|
|
||||||
|
# Rope project settings |
||||||
|
.ropeproject |
||||||
|
|
||||||
|
# mkdocs documentation |
||||||
|
/site |
||||||
|
|
||||||
|
# mypy |
||||||
|
.mypy_cache/ |
||||||
|
.dmypy.json |
||||||
|
dmypy.json |
||||||
|
|
||||||
|
# Pyre type checker |
||||||
|
.pyre/ |
||||||
|
|
||||||
|
# Mac system file |
||||||
|
model/ |
@ -0,0 +1,18 @@ |
|||||||
|
# 友情链接 |
||||||
|
|
||||||
|
以下是部分基于本仓库开发的开源项目: |
||||||
|
* [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。 |
||||||
|
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU |
||||||
|
* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf) |
||||||
|
* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于本地知识的 ChatGLM 应用,基于LangChain |
||||||
|
* [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。 |
||||||
|
* [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能 |
||||||
|
* [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署 |
||||||
|
* [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。 |
||||||
|
* [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM):基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题 |
||||||
|
* [ChatGLM-web](https://github.com/NCZkevin/chatglm-web):基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能) |
||||||
|
* [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM |
||||||
|
|
||||||
|
以下是部分针对本项目的教程/文档: |
||||||
|
* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md) |
||||||
|
* [ChatGLM-6B 的部署与微调教程 @ModelWhale平台](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e) |
@ -0,0 +1,248 @@ |
|||||||
|
# ChatGLM-6B-PT |
||||||
|
本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。 |
||||||
|
|
||||||
|
下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。 |
||||||
|
|
||||||
|
*Read this in [English](README_en.md).* |
||||||
|
|
||||||
|
## 软件依赖 |
||||||
|
运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖 |
||||||
|
``` |
||||||
|
pip install rouge_chinese nltk jieba datasets |
||||||
|
``` |
||||||
|
## 使用方法 |
||||||
|
|
||||||
|
### 下载数据集 |
||||||
|
ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。 |
||||||
|
|
||||||
|
```json |
||||||
|
{ |
||||||
|
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", |
||||||
|
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" |
||||||
|
} |
||||||
|
``` |
||||||
|
|
||||||
|
从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 ADGEN 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。 |
||||||
|
|
||||||
|
### 训练 |
||||||
|
|
||||||
|
#### P-tuning v2 |
||||||
|
|
||||||
|
运行以下指令进行训练: |
||||||
|
```shell |
||||||
|
bash train.sh |
||||||
|
``` |
||||||
|
`train.sh` 中的 `PRE_SEQ_LEN` 和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 `quantization_bit` 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。 |
||||||
|
|
||||||
|
在默认配置 `quantization_bit=4`、`per_device_train_batch_size=1`、`gradient_accumulation_steps=16` 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 `per_device_train_batch_size` 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。 |
||||||
|
|
||||||
|
如果你想要[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B),可以将 `train.sh` 中的 `THUDM/chatglm-6b` 改为你本地的模型路径。 |
||||||
|
|
||||||
|
#### Finetune |
||||||
|
|
||||||
|
如果需要进行全参数的 Finetune,需要安装 [Deepspeed](https://github.com/microsoft/DeepSpeed),然后运行以下指令: |
||||||
|
|
||||||
|
```shell |
||||||
|
bash ds_train_finetune.sh |
||||||
|
``` |
||||||
|
|
||||||
|
### 推理 |
||||||
|
|
||||||
|
将 `evaluate.sh` 中的 `CHECKPOINT` 更改为训练时保存的 checkpoint 名称,运行以下指令进行模型推理和评测: |
||||||
|
```shell |
||||||
|
bash evaluate.sh |
||||||
|
``` |
||||||
|
**[2023/04/10更新]** 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重,因此需要指定参数(已更新 `evaluate.sh`) : |
||||||
|
|
||||||
|
```shell |
||||||
|
--model_name_or_path THUDM/chatglm-6b |
||||||
|
--ptuning_checkpoint $CHECKPOINT_PATH |
||||||
|
``` |
||||||
|
|
||||||
|
仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`: |
||||||
|
|
||||||
|
```shell |
||||||
|
--model_name_or_path $CHECKPOINT_PATH |
||||||
|
``` |
||||||
|
|
||||||
|
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 |
||||||
|
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。 |
||||||
|
|
||||||
|
### 例子 |
||||||
|
#### 示例1 |
||||||
|
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 |
||||||
|
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 |
||||||
|
* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 |
||||||
|
* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 |
||||||
|
|
||||||
|
#### 示例2 |
||||||
|
|
||||||
|
* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 |
||||||
|
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 |
||||||
|
* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 |
||||||
|
* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 |
||||||
|
|
||||||
|
### 评估结果 |
||||||
|
|
||||||
|
| | Finetune | P-tuning v2 | LoRA | |
||||||
|
| ------------- | ----------- | ----- | ------------- | |
||||||
|
| BLEU-4 | 8.01 | 8.10 | 7.62 | |
||||||
|
| Rouge-1 | 31.23 | 31.12 | 30.60 | |
||||||
|
| Rouge-2 | 7.36 | 7.11 | 6.96 | |
||||||
|
| Rouge-l | 25.08 | 24.97 | 24.80 | |
||||||
|
| Training Loss | 3.00 | 3.74 | 3.32 | |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### 实验设置 |
||||||
|
|
||||||
|
``` |
||||||
|
max_source_length=64 |
||||||
|
max_target_length=64 |
||||||
|
max_steps=3000 |
||||||
|
``` |
||||||
|
|
||||||
|
##### P-tuning v2 |
||||||
|
|
||||||
|
``` |
||||||
|
pre_seq_len=128 |
||||||
|
learning_rate=2e-2 |
||||||
|
quantization_bit=4 |
||||||
|
per_device_train_batch_size=16 |
||||||
|
gradient_accumulation_steps=1 |
||||||
|
``` |
||||||
|
|
||||||
|
##### Finetune |
||||||
|
|
||||||
|
``` |
||||||
|
learning_rate=1e-4 |
||||||
|
fp16 |
||||||
|
num_gpus=4 |
||||||
|
per_device_train_batch_size=4 |
||||||
|
gradient_accumulation_steps=1 |
||||||
|
``` |
||||||
|
|
||||||
|
##### LoRA |
||||||
|
|
||||||
|
实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) |
||||||
|
|
||||||
|
``` |
||||||
|
learning_rate=5e-4 |
||||||
|
per_device_train_batch_size=16 |
||||||
|
gradient_accumulation_steps=1 |
||||||
|
``` |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 模型部署 |
||||||
|
首先载入Tokenizer: |
||||||
|
|
||||||
|
```python |
||||||
|
import os |
||||||
|
import torch |
||||||
|
from transformers import AutoConfig, AutoModel, AutoTokenizer |
||||||
|
|
||||||
|
# 载入Tokenizer |
||||||
|
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) |
||||||
|
``` |
||||||
|
|
||||||
|
1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数): |
||||||
|
|
||||||
|
```python |
||||||
|
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128) |
||||||
|
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True) |
||||||
|
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin")) |
||||||
|
new_prefix_state_dict = {} |
||||||
|
for k, v in prefix_state_dict.items(): |
||||||
|
if k.startswith("transformer.prefix_encoder."): |
||||||
|
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v |
||||||
|
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) |
||||||
|
``` |
||||||
|
注意你可能需要将 `pre_seq_len` 改成你训练时的实际值。如果你是[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)的话,需要将 `THUDM/chatglm-6b` 改成本地的模型路径(注意不是checkpoint路径)。 |
||||||
|
|
||||||
|
2. 如果需要加载的是旧 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 参数),或者进行的是全参数微调,则直接加载整个 Checkpoint: |
||||||
|
|
||||||
|
```python |
||||||
|
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True) |
||||||
|
``` |
||||||
|
|
||||||
|
之后根据需求可以进行量化,也可以直接使用: |
||||||
|
|
||||||
|
```python |
||||||
|
# Comment out the following line if you don't use quantization |
||||||
|
model = model.quantize(4) |
||||||
|
model = model.half().cuda() |
||||||
|
model.transformer.prefix_encoder.float() |
||||||
|
model = model.eval() |
||||||
|
|
||||||
|
response, history = model.chat(tokenizer, "你好", history=[]) |
||||||
|
``` |
||||||
|
|
||||||
|
## 使用自己的数据集 |
||||||
|
修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 `max_source_length` 和 `max_target_length` 来匹配你自己的数据集中的最大输入输出长度。 |
||||||
|
|
||||||
|
## 对话数据集 |
||||||
|
|
||||||
|
如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如 |
||||||
|
|
||||||
|
```json |
||||||
|
{ |
||||||
|
"prompt": "是的。上下水管都好的", |
||||||
|
"response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", |
||||||
|
"history": [ |
||||||
|
[ |
||||||
|
"长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", |
||||||
|
"用电脑能读数据流吗?水温多少" |
||||||
|
], |
||||||
|
[ |
||||||
|
"95", |
||||||
|
"上下水管温差怎么样啊?空气是不是都排干净了呢?" |
||||||
|
] |
||||||
|
] |
||||||
|
} |
||||||
|
``` |
||||||
|
|
||||||
|
训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接,例如: |
||||||
|
|
||||||
|
- Input |
||||||
|
|
||||||
|
``` |
||||||
|
[Round 0] |
||||||
|
问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线 |
||||||
|
答:用电脑能读数据流吗?水温多少 |
||||||
|
[Round 1] |
||||||
|
问:95 |
||||||
|
答:上下水管温差怎么样啊?空气是不是都排干净了呢? |
||||||
|
[Round 2] |
||||||
|
问:是的。上下水管都好的 |
||||||
|
答: |
||||||
|
``` |
||||||
|
|
||||||
|
- Label |
||||||
|
|
||||||
|
``` |
||||||
|
那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况! |
||||||
|
``` |
||||||
|
|
||||||
|
要注意超过输入长度 `max_source_length` 的内容会被截。 |
||||||
|
|
||||||
|
可以参考以下指令: |
||||||
|
|
||||||
|
```shell |
||||||
|
bash train_chat.sh |
||||||
|
``` |
||||||
|
|
||||||
|
## 引用 |
||||||
|
|
||||||
|
``` |
||||||
|
@inproceedings{liu2022p, |
||||||
|
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, |
||||||
|
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, |
||||||
|
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, |
||||||
|
pages={61--68}, |
||||||
|
year={2022} |
||||||
|
} |
||||||
|
``` |
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,115 @@ |
|||||||
|
# ChatGLM-6B-PT |
||||||
|
This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run. |
||||||
|
|
||||||
|
The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code. |
||||||
|
|
||||||
|
## Software dependencies |
||||||
|
Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required |
||||||
|
``` |
||||||
|
pip install rouge_chinese nltk jieba datasets |
||||||
|
``` |
||||||
|
## Instructions |
||||||
|
|
||||||
|
### Download the dataset |
||||||
|
The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content). |
||||||
|
|
||||||
|
```json |
||||||
|
{ |
||||||
|
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", |
||||||
|
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" |
||||||
|
} |
||||||
|
``` |
||||||
|
|
||||||
|
From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) Download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory. |
||||||
|
|
||||||
|
### Training |
||||||
|
Run the following commands for training: |
||||||
|
```shell |
||||||
|
bash train.sh |
||||||
|
``` |
||||||
|
`PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision. |
||||||
|
|
||||||
|
Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation. |
||||||
|
|
||||||
|
### Inference |
||||||
|
|
||||||
|
Change `CHECKPOINT` in `evaluate.sh` to the checkpoint name saved during training, and run the following commands for model inference and evaluation: |
||||||
|
```shell |
||||||
|
bash evaluate.sh |
||||||
|
``` |
||||||
|
|
||||||
|
The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in |
||||||
|
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`. |
||||||
|
|
||||||
|
### Example |
||||||
|
#### Example 1 |
||||||
|
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 |
||||||
|
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 |
||||||
|
* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 |
||||||
|
* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 |
||||||
|
|
||||||
|
#### Example 2 |
||||||
|
|
||||||
|
* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 |
||||||
|
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 |
||||||
|
* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 |
||||||
|
* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 |
||||||
|
|
||||||
|
### evaluation result |
||||||
|
|
||||||
|
| | P-tuning v2 | LoRA | |
||||||
|
| ------- | ----------- | ----- | |
||||||
|
| BLEU-4 | 7.71 | 6.13 | |
||||||
|
| Rouge-1 | 31.35 | 28.36 | |
||||||
|
| Rouge-2 | 7.19 | 4.38 | |
||||||
|
| Rouge-l | 25.17 | 17.54 | |
||||||
|
|
||||||
|
#### Experiment Settings |
||||||
|
|
||||||
|
``` |
||||||
|
max_source_length=64 |
||||||
|
max_target_length=64 |
||||||
|
per_device_train_batch_size=1 |
||||||
|
gradient_accumulation_steps=16 |
||||||
|
max_steps=3000 |
||||||
|
``` |
||||||
|
|
||||||
|
##### P-tuning v2 |
||||||
|
|
||||||
|
``` |
||||||
|
pre_seq_len=128 |
||||||
|
learning_rate=2e-2 |
||||||
|
quantization_bit=4 |
||||||
|
``` |
||||||
|
|
||||||
|
##### LoRA |
||||||
|
|
||||||
|
``` |
||||||
|
learning_rate=5e-4 |
||||||
|
``` |
||||||
|
|
||||||
|
The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Model Deployment |
||||||
|
Replace `THUDM/chatglm-6b` in the corresponding demo or code with the path of the checkpoint after P-Tuning(in the example, `./output/adgen-chatglm-6b-pt-8-1e-2/ checkpoint-3000`). Note that the current fine-tuning does not support multiple rounds of data, so only the responses from the first round of the conversation are fine-tuned. |
||||||
|
|
||||||
|
## Use your own dataset |
||||||
|
Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text. |
||||||
|
|
||||||
|
## TODO |
||||||
|
* [ ] Support for chat data |
||||||
|
* [ ] Support for full finetuning |
||||||
|
|
||||||
|
## quoting |
||||||
|
|
||||||
|
``` |
||||||
|
@inproceedings{liu2022p, |
||||||
|
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, |
||||||
|
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, |
||||||
|
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, |
||||||
|
pages={61--68}, |
||||||
|
year={2022} |
||||||
|
} |
||||||
|
``` |
@ -0,0 +1,224 @@ |
|||||||
|
from dataclasses import dataclass, field |
||||||
|
from typing import Optional |
||||||
|
|
||||||
|
|
||||||
|
@dataclass |
||||||
|
class ModelArguments: |
||||||
|
""" |
||||||
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. |
||||||
|
""" |
||||||
|
|
||||||
|
model_name_or_path: str = field( |
||||||
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} |
||||||
|
) |
||||||
|
ptuning_checkpoint: str = field( |
||||||
|
default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} |
||||||
|
) |
||||||
|
config_name: Optional[str] = field( |
||||||
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} |
||||||
|
) |
||||||
|
tokenizer_name: Optional[str] = field( |
||||||
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} |
||||||
|
) |
||||||
|
cache_dir: Optional[str] = field( |
||||||
|
default=None, |
||||||
|
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, |
||||||
|
) |
||||||
|
use_fast_tokenizer: bool = field( |
||||||
|
default=True, |
||||||
|
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, |
||||||
|
) |
||||||
|
model_revision: str = field( |
||||||
|
default="main", |
||||||
|
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, |
||||||
|
) |
||||||
|
use_auth_token: bool = field( |
||||||
|
default=False, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"Will use the token generated when running `huggingface-cli login` (necessary to use this script " |
||||||
|
"with private models)." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
resize_position_embeddings: Optional[bool] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"Whether to automatically resize the position embeddings if `max_source_length` exceeds " |
||||||
|
"the model's position embeddings." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
quantization_bit: Optional[int] = field( |
||||||
|
default=None |
||||||
|
) |
||||||
|
pre_seq_len: Optional[int] = field( |
||||||
|
default=None |
||||||
|
) |
||||||
|
prefix_projection: bool = field( |
||||||
|
default=False |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
@dataclass |
||||||
|
class DataTrainingArguments: |
||||||
|
""" |
||||||
|
Arguments pertaining to what data we are going to input our model for training and eval. |
||||||
|
""" |
||||||
|
|
||||||
|
lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) |
||||||
|
|
||||||
|
dataset_name: Optional[str] = field( |
||||||
|
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} |
||||||
|
) |
||||||
|
dataset_config_name: Optional[str] = field( |
||||||
|
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} |
||||||
|
) |
||||||
|
prompt_column: Optional[str] = field( |
||||||
|
default=None, |
||||||
|
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, |
||||||
|
) |
||||||
|
response_column: Optional[str] = field( |
||||||
|
default=None, |
||||||
|
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, |
||||||
|
) |
||||||
|
history_column: Optional[str] = field( |
||||||
|
default=None, |
||||||
|
metadata={"help": "The name of the column in the datasets containing the history of chat."}, |
||||||
|
) |
||||||
|
train_file: Optional[str] = field( |
||||||
|
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} |
||||||
|
) |
||||||
|
validation_file: Optional[str] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
test_file: Optional[str] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." |
||||||
|
}, |
||||||
|
) |
||||||
|
overwrite_cache: bool = field( |
||||||
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} |
||||||
|
) |
||||||
|
preprocessing_num_workers: Optional[int] = field( |
||||||
|
default=None, |
||||||
|
metadata={"help": "The number of processes to use for the preprocessing."}, |
||||||
|
) |
||||||
|
max_source_length: Optional[int] = field( |
||||||
|
default=1024, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"The maximum total input sequence length after tokenization. Sequences longer " |
||||||
|
"than this will be truncated, sequences shorter will be padded." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
max_target_length: Optional[int] = field( |
||||||
|
default=128, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"The maximum total sequence length for target text after tokenization. Sequences longer " |
||||||
|
"than this will be truncated, sequences shorter will be padded." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
val_max_target_length: Optional[int] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"The maximum total sequence length for validation target text after tokenization. Sequences longer " |
||||||
|
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." |
||||||
|
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " |
||||||
|
"during ``evaluate`` and ``predict``." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
pad_to_max_length: bool = field( |
||||||
|
default=False, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"Whether to pad all samples to model maximum sentence length. " |
||||||
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More " |
||||||
|
"efficient on GPU but very bad for TPU." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
max_train_samples: Optional[int] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this " |
||||||
|
"value if set." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
max_eval_samples: Optional[int] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this " |
||||||
|
"value if set." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
max_predict_samples: Optional[int] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this " |
||||||
|
"value if set." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
num_beams: Optional[int] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " |
||||||
|
"which is used during ``evaluate`` and ``predict``." |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
ignore_pad_token_for_loss: bool = field( |
||||||
|
default=True, |
||||||
|
metadata={ |
||||||
|
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." |
||||||
|
}, |
||||||
|
) |
||||||
|
source_prefix: Optional[str] = field( |
||||||
|
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} |
||||||
|
) |
||||||
|
|
||||||
|
forced_bos_token: Optional[str] = field( |
||||||
|
default=None, |
||||||
|
metadata={ |
||||||
|
"help": ( |
||||||
|
"The token to force as the first generated token after the decoder_start_token_id." |
||||||
|
"Useful for multilingual models like mBART where the first generated token" |
||||||
|
"needs to be the target language token (Usually it is the target language token)" |
||||||
|
) |
||||||
|
}, |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def __post_init__(self): |
||||||
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: |
||||||
|
raise ValueError("Need either a dataset name or a training/validation/test file.") |
||||||
|
else: |
||||||
|
if self.train_file is not None: |
||||||
|
extension = self.train_file.split(".")[-1] |
||||||
|
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." |
||||||
|
if self.validation_file is not None: |
||||||
|
extension = self.validation_file.split(".")[-1] |
||||||
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." |
||||||
|
if self.val_max_target_length is None: |
||||||
|
self.val_max_target_length = self.max_target_length |
||||||
|
|
@ -0,0 +1,21 @@ |
|||||||
|
{ |
||||||
|
"train_micro_batch_size_per_gpu": "auto", |
||||||
|
"zero_allow_untested_optimizer": true, |
||||||
|
"fp16": { |
||||||
|
"enabled": "auto", |
||||||
|
"loss_scale": 0, |
||||||
|
"initial_scale_power": 16, |
||||||
|
"loss_scale_window": 1000, |
||||||
|
"hysteresis": 2, |
||||||
|
"min_loss_scale": 1 |
||||||
|
}, |
||||||
|
"zero_optimization": { |
||||||
|
"stage": 2, |
||||||
|
"allgather_partitions": true, |
||||||
|
"allgather_bucket_size": 5e8, |
||||||
|
"overlap_comm": false, |
||||||
|
"reduce_scatter": true, |
||||||
|
"reduce_bucket_size": 5e8, |
||||||
|
"contiguous_gradients" : true |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,28 @@ |
|||||||
|
|
||||||
|
LR=1e-4 |
||||||
|
|
||||||
|
MASTER_PORT=$(shuf -n 1 -i 10000-65535) |
||||||
|
|
||||||
|
deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \ |
||||||
|
--deepspeed deepspeed.json \ |
||||||
|
--do_train \ |
||||||
|
--train_file AdvertiseGen/train.json \ |
||||||
|
--test_file AdvertiseGen/dev.json \ |
||||||
|
--prompt_column content \ |
||||||
|
--response_column summary \ |
||||||
|
--overwrite_cache \ |
||||||
|
--model_name_or_path THUDM/chatglm-6b \ |
||||||
|
--output_dir ./output/adgen-chatglm-6b-ft-$LR \ |
||||||
|
--overwrite_output_dir \ |
||||||
|
--max_source_length 64 \ |
||||||
|
--max_target_length 64 \ |
||||||
|
--per_device_train_batch_size 4 \ |
||||||
|
--per_device_eval_batch_size 1 \ |
||||||
|
--gradient_accumulation_steps 1 \ |
||||||
|
--predict_with_generate \ |
||||||
|
--max_steps 5000 \ |
||||||
|
--logging_steps 10 \ |
||||||
|
--save_steps 1000 \ |
||||||
|
--learning_rate $LR \ |
||||||
|
--fp16 |
||||||
|
|
@ -0,0 +1,21 @@ |
|||||||
|
PRE_SEQ_LEN=128 |
||||||
|
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2 |
||||||
|
STEP=3000 |
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python3 main.py \ |
||||||
|
--do_predict \ |
||||||
|
--validation_file AdvertiseGen/dev.json \ |
||||||
|
--test_file AdvertiseGen/dev.json \ |
||||||
|
--overwrite_cache \ |
||||||
|
--prompt_column content \ |
||||||
|
--response_column summary \ |
||||||
|
--model_name_or_path THUDM/chatglm-6b \ |
||||||
|
--ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \ |
||||||
|
--output_dir ./output/$CHECKPOINT \ |
||||||
|
--overwrite_output_dir \ |
||||||
|
--max_source_length 64 \ |
||||||
|
--max_target_length 64 \ |
||||||
|
--per_device_eval_batch_size 1 \ |
||||||
|
--predict_with_generate \ |
||||||
|
--pre_seq_len $PRE_SEQ_LEN \ |
||||||
|
--quantization_bit 4 |
@ -0,0 +1,18 @@ |
|||||||
|
CHECKPOINT=adgen-chatglm-6b-ft-1e-4 |
||||||
|
STEP=3000 |
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python3 main.py \ |
||||||
|
--do_predict \ |
||||||
|
--validation_file AdvertiseGen/dev.json \ |
||||||
|
--test_file AdvertiseGen/dev.json \ |
||||||
|
--overwrite_cache \ |
||||||
|
--prompt_column content \ |
||||||
|
--response_column summary \ |
||||||
|
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \ |
||||||
|
--output_dir ./output/$CHECKPOINT \ |
||||||
|
--overwrite_output_dir \ |
||||||
|
--max_source_length 256 \ |
||||||
|
--max_target_length 256 \ |
||||||
|
--per_device_eval_batch_size 1 \ |
||||||
|
--predict_with_generate \ |
||||||
|
--fp16_full_eval |
@ -0,0 +1,431 @@ |
|||||||
|
#!/usr/bin/env python |
||||||
|
# coding=utf-8 |
||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
""" |
||||||
|
Fine-tuning the library models for sequence to sequence. |
||||||
|
""" |
||||||
|
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. |
||||||
|
|
||||||
|
import logging |
||||||
|
import os |
||||||
|
import sys |
||||||
|
import json |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
from datasets import load_dataset |
||||||
|
import jieba |
||||||
|
from rouge_chinese import Rouge |
||||||
|
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction |
||||||
|
import torch |
||||||
|
|
||||||
|
import transformers |
||||||
|
from transformers import ( |
||||||
|
AutoConfig, |
||||||
|
AutoModel, |
||||||
|
AutoTokenizer, |
||||||
|
AutoTokenizer, |
||||||
|
DataCollatorForSeq2Seq, |
||||||
|
HfArgumentParser, |
||||||
|
Seq2SeqTrainingArguments, |
||||||
|
set_seed, |
||||||
|
) |
||||||
|
from trainer_seq2seq import Seq2SeqTrainer |
||||||
|
|
||||||
|
from arguments import ModelArguments, DataTrainingArguments |
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
def main(): |
||||||
|
|
||||||
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) |
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): |
||||||
|
# If we pass only one argument to the script and it's the path to a json file, |
||||||
|
# let's parse it to get our arguments. |
||||||
|
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
||||||
|
else: |
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
||||||
|
|
||||||
|
# Setup logging |
||||||
|
logging.basicConfig( |
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
||||||
|
datefmt="%m/%d/%Y %H:%M:%S", |
||||||
|
handlers=[logging.StreamHandler(sys.stdout)], |
||||||
|
) |
||||||
|
|
||||||
|
if training_args.should_log: |
||||||
|
# The default of training_args.log_level is passive, so we set log level at info here to have that default. |
||||||
|
transformers.utils.logging.set_verbosity_info() |
||||||
|
|
||||||
|
log_level = training_args.get_process_log_level() |
||||||
|
logger.setLevel(log_level) |
||||||
|
# datasets.utils.logging.set_verbosity(log_level) |
||||||
|
transformers.utils.logging.set_verbosity(log_level) |
||||||
|
transformers.utils.logging.enable_default_handler() |
||||||
|
transformers.utils.logging.enable_explicit_format() |
||||||
|
|
||||||
|
# Log on each process the small summary: |
||||||
|
logger.warning( |
||||||
|
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" |
||||||
|
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" |
||||||
|
) |
||||||
|
logger.info(f"Training/evaluation parameters {training_args}") |
||||||
|
|
||||||
|
# Set seed before initializing model. |
||||||
|
set_seed(training_args.seed) |
||||||
|
|
||||||
|
# Load dataset |
||||||
|
data_files = {} |
||||||
|
if data_args.train_file is not None: |
||||||
|
data_files["train"] = data_args.train_file |
||||||
|
extension = data_args.train_file.split(".")[-1] |
||||||
|
if data_args.validation_file is not None: |
||||||
|
data_files["validation"] = data_args.validation_file |
||||||
|
extension = data_args.validation_file.split(".")[-1] |
||||||
|
if data_args.test_file is not None: |
||||||
|
data_files["test"] = data_args.test_file |
||||||
|
extension = data_args.test_file.split(".")[-1] |
||||||
|
|
||||||
|
raw_datasets = load_dataset( |
||||||
|
extension, |
||||||
|
data_files=data_files, |
||||||
|
cache_dir=model_args.cache_dir, |
||||||
|
use_auth_token=True if model_args.use_auth_token else None, |
||||||
|
) |
||||||
|
|
||||||
|
# Load pretrained model and tokenizer |
||||||
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) |
||||||
|
config.pre_seq_len = model_args.pre_seq_len |
||||||
|
config.prefix_projection = model_args.prefix_projection |
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) |
||||||
|
|
||||||
|
if model_args.ptuning_checkpoint is not None: |
||||||
|
# Evaluation |
||||||
|
# Loading extra state dict of prefix encoder |
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) |
||||||
|
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) |
||||||
|
new_prefix_state_dict = {} |
||||||
|
for k, v in prefix_state_dict.items(): |
||||||
|
if k.startswith("transformer.prefix_encoder."): |
||||||
|
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v |
||||||
|
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) |
||||||
|
else: |
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) |
||||||
|
|
||||||
|
if model_args.quantization_bit is not None: |
||||||
|
print(f"Quantized to {model_args.quantization_bit} bit") |
||||||
|
model = model.quantize(model_args.quantization_bit) |
||||||
|
if model_args.pre_seq_len is not None: |
||||||
|
# P-tuning v2 |
||||||
|
model = model.half() |
||||||
|
model.transformer.prefix_encoder.float() |
||||||
|
else: |
||||||
|
# Finetune |
||||||
|
model = model.float() |
||||||
|
|
||||||
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else "" |
||||||
|
|
||||||
|
# Preprocessing the datasets. |
||||||
|
# We need to tokenize inputs and targets. |
||||||
|
if training_args.do_train: |
||||||
|
column_names = raw_datasets["train"].column_names |
||||||
|
elif training_args.do_eval: |
||||||
|
column_names = raw_datasets["validation"].column_names |
||||||
|
elif training_args.do_predict: |
||||||
|
column_names = raw_datasets["test"].column_names |
||||||
|
else: |
||||||
|
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") |
||||||
|
return |
||||||
|
|
||||||
|
# Get the column names for input/target. |
||||||
|
prompt_column = data_args.prompt_column |
||||||
|
response_column = data_args.response_column |
||||||
|
history_column = data_args.history_column |
||||||
|
|
||||||
|
# Temporarily set max_target_length for training. |
||||||
|
max_target_length = data_args.max_target_length |
||||||
|
|
||||||
|
def preprocess_function_eval(examples): |
||||||
|
inputs, targets = [], [] |
||||||
|
for i in range(len(examples[prompt_column])): |
||||||
|
if examples[prompt_column][i] and examples[response_column][i]: |
||||||
|
query = examples[prompt_column][i] |
||||||
|
if history_column is None or len(examples[history_column][i]) == 0: |
||||||
|
prompt = query |
||||||
|
else: |
||||||
|
prompt = "" |
||||||
|
history = examples[history_column][i] |
||||||
|
for turn_idx, (old_query, response) in enumerate(history): |
||||||
|
prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) |
||||||
|
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) |
||||||
|
inputs.append(prompt) |
||||||
|
targets.append(examples[response_column][i]) |
||||||
|
|
||||||
|
inputs = [prefix + inp for inp in inputs] |
||||||
|
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) |
||||||
|
labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) |
||||||
|
|
||||||
|
if data_args.ignore_pad_token_for_loss: |
||||||
|
labels["input_ids"] = [ |
||||||
|
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] |
||||||
|
] |
||||||
|
model_inputs["labels"] = labels["input_ids"] |
||||||
|
|
||||||
|
return model_inputs |
||||||
|
|
||||||
|
def preprocess_function_train(examples): |
||||||
|
max_seq_length = data_args.max_source_length + data_args.max_target_length |
||||||
|
|
||||||
|
model_inputs = { |
||||||
|
"input_ids": [], |
||||||
|
"labels": [], |
||||||
|
} |
||||||
|
for i in range(len(examples[prompt_column])): |
||||||
|
if examples[prompt_column][i] and examples[response_column][i]: |
||||||
|
query, answer = examples[prompt_column][i], examples[response_column][i] |
||||||
|
|
||||||
|
if history_column is None: |
||||||
|
prompt = query |
||||||
|
else: |
||||||
|
prompt = "" |
||||||
|
history = examples[history_column][i] |
||||||
|
for turn_idx, (old_query, response) in enumerate(history): |
||||||
|
prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) |
||||||
|
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) |
||||||
|
|
||||||
|
prompt = prefix + prompt |
||||||
|
a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) |
||||||
|
b_ids = tokenizer.encode(text=answer, add_special_tokens=False) |
||||||
|
|
||||||
|
if len(a_ids) > data_args.max_source_length - 1: |
||||||
|
a_ids = a_ids[: data_args.max_source_length - 1] |
||||||
|
|
||||||
|
if len(b_ids) > data_args.max_target_length - 2: |
||||||
|
b_ids = b_ids[: data_args.max_target_length - 2] |
||||||
|
|
||||||
|
input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) |
||||||
|
|
||||||
|
context_length = input_ids.index(tokenizer.bos_token_id) |
||||||
|
mask_position = context_length - 1 |
||||||
|
labels = [-100] * context_length + input_ids[mask_position+1:] |
||||||
|
|
||||||
|
pad_len = max_seq_length - len(input_ids) |
||||||
|
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len |
||||||
|
labels = labels + [tokenizer.pad_token_id] * pad_len |
||||||
|
if data_args.ignore_pad_token_for_loss: |
||||||
|
labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] |
||||||
|
|
||||||
|
model_inputs["input_ids"].append(input_ids) |
||||||
|
model_inputs["labels"].append(labels) |
||||||
|
|
||||||
|
return model_inputs |
||||||
|
|
||||||
|
def print_dataset_example(example): |
||||||
|
print("input_ids",example["input_ids"]) |
||||||
|
print("inputs", tokenizer.decode(example["input_ids"])) |
||||||
|
print("label_ids", example["labels"]) |
||||||
|
print("labels", tokenizer.decode(example["labels"])) |
||||||
|
|
||||||
|
if training_args.do_train: |
||||||
|
if "train" not in raw_datasets: |
||||||
|
raise ValueError("--do_train requires a train dataset") |
||||||
|
train_dataset = raw_datasets["train"] |
||||||
|
if data_args.max_train_samples is not None: |
||||||
|
max_train_samples = min(len(train_dataset), data_args.max_train_samples) |
||||||
|
train_dataset = train_dataset.select(range(max_train_samples)) |
||||||
|
with training_args.main_process_first(desc="train dataset map pre-processing"): |
||||||
|
train_dataset = train_dataset.map( |
||||||
|
preprocess_function_train, |
||||||
|
batched=True, |
||||||
|
num_proc=data_args.preprocessing_num_workers, |
||||||
|
remove_columns=column_names, |
||||||
|
load_from_cache_file=not data_args.overwrite_cache, |
||||||
|
desc="Running tokenizer on train dataset", |
||||||
|
) |
||||||
|
print_dataset_example(train_dataset[0]) |
||||||
|
|
||||||
|
if training_args.do_eval: |
||||||
|
max_target_length = data_args.val_max_target_length |
||||||
|
if "validation" not in raw_datasets: |
||||||
|
raise ValueError("--do_eval requires a validation dataset") |
||||||
|
eval_dataset = raw_datasets["validation"] |
||||||
|
if data_args.max_eval_samples is not None: |
||||||
|
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) |
||||||
|
eval_dataset = eval_dataset.select(range(max_eval_samples)) |
||||||
|
with training_args.main_process_first(desc="validation dataset map pre-processing"): |
||||||
|
eval_dataset = eval_dataset.map( |
||||||
|
preprocess_function_eval, |
||||||
|
batched=True, |
||||||
|
num_proc=data_args.preprocessing_num_workers, |
||||||
|
remove_columns=column_names, |
||||||
|
load_from_cache_file=not data_args.overwrite_cache, |
||||||
|
desc="Running tokenizer on validation dataset", |
||||||
|
) |
||||||
|
print_dataset_example(eval_dataset[0]) |
||||||
|
|
||||||
|
if training_args.do_predict: |
||||||
|
max_target_length = data_args.val_max_target_length |
||||||
|
if "test" not in raw_datasets: |
||||||
|
raise ValueError("--do_predict requires a test dataset") |
||||||
|
predict_dataset = raw_datasets["test"] |
||||||
|
if data_args.max_predict_samples is not None: |
||||||
|
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) |
||||||
|
predict_dataset = predict_dataset.select(range(max_predict_samples)) |
||||||
|
with training_args.main_process_first(desc="prediction dataset map pre-processing"): |
||||||
|
predict_dataset = predict_dataset.map( |
||||||
|
preprocess_function_eval, |
||||||
|
batched=True, |
||||||
|
num_proc=data_args.preprocessing_num_workers, |
||||||
|
remove_columns=column_names, |
||||||
|
load_from_cache_file=not data_args.overwrite_cache, |
||||||
|
desc="Running tokenizer on prediction dataset", |
||||||
|
) |
||||||
|
print_dataset_example(predict_dataset[0]) |
||||||
|
|
||||||
|
# Data collator |
||||||
|
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id |
||||||
|
data_collator = DataCollatorForSeq2Seq( |
||||||
|
tokenizer, |
||||||
|
model=model, |
||||||
|
label_pad_token_id=label_pad_token_id, |
||||||
|
pad_to_multiple_of=None, |
||||||
|
padding=False |
||||||
|
) |
||||||
|
|
||||||
|
# Metric |
||||||
|
def compute_metrics(eval_preds): |
||||||
|
preds, labels = eval_preds |
||||||
|
if isinstance(preds, tuple): |
||||||
|
preds = preds[0] |
||||||
|
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) |
||||||
|
if data_args.ignore_pad_token_for_loss: |
||||||
|
# Replace -100 in the labels as we can't decode them. |
||||||
|
labels = np.where(labels != -100, labels, tokenizer.pad_token_id) |
||||||
|
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) |
||||||
|
|
||||||
|
score_dict = { |
||||||
|
"rouge-1": [], |
||||||
|
"rouge-2": [], |
||||||
|
"rouge-l": [], |
||||||
|
"bleu-4": [] |
||||||
|
} |
||||||
|
for pred, label in zip(decoded_preds, decoded_labels): |
||||||
|
hypothesis = list(jieba.cut(pred)) |
||||||
|
reference = list(jieba.cut(label)) |
||||||
|
rouge = Rouge() |
||||||
|
scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) |
||||||
|
result = scores[0] |
||||||
|
|
||||||
|
for k, v in result.items(): |
||||||
|
score_dict[k].append(round(v["f"] * 100, 4)) |
||||||
|
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) |
||||||
|
score_dict["bleu-4"].append(round(bleu_score * 100, 4)) |
||||||
|
|
||||||
|
for k, v in score_dict.items(): |
||||||
|
score_dict[k] = float(np.mean(v)) |
||||||
|
return score_dict |
||||||
|
|
||||||
|
# Override the decoding parameters of Seq2SeqTrainer |
||||||
|
training_args.generation_max_length = ( |
||||||
|
training_args.generation_max_length |
||||||
|
if training_args.generation_max_length is not None |
||||||
|
else data_args.val_max_target_length |
||||||
|
) |
||||||
|
training_args.generation_num_beams = ( |
||||||
|
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams |
||||||
|
) |
||||||
|
# Initialize our Trainer |
||||||
|
trainer = Seq2SeqTrainer( |
||||||
|
model=model, |
||||||
|
args=training_args, |
||||||
|
train_dataset=train_dataset if training_args.do_train else None, |
||||||
|
eval_dataset=eval_dataset if training_args.do_eval else None, |
||||||
|
tokenizer=tokenizer, |
||||||
|
data_collator=data_collator, |
||||||
|
compute_metrics=compute_metrics if training_args.predict_with_generate else None, |
||||||
|
save_prefixencoder=model_args.pre_seq_len is not None |
||||||
|
) |
||||||
|
|
||||||
|
# Training |
||||||
|
if training_args.do_train: |
||||||
|
checkpoint = None |
||||||
|
if training_args.resume_from_checkpoint is not None: |
||||||
|
checkpoint = training_args.resume_from_checkpoint |
||||||
|
# elif last_checkpoint is not None: |
||||||
|
# checkpoint = last_checkpoint |
||||||
|
model.gradient_checkpointing_enable() |
||||||
|
model.enable_input_require_grads() |
||||||
|
train_result = trainer.train(resume_from_checkpoint=checkpoint) |
||||||
|
# trainer.save_model() # Saves the tokenizer too for easy upload |
||||||
|
|
||||||
|
metrics = train_result.metrics |
||||||
|
max_train_samples = ( |
||||||
|
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) |
||||||
|
) |
||||||
|
metrics["train_samples"] = min(max_train_samples, len(train_dataset)) |
||||||
|
|
||||||
|
trainer.log_metrics("train", metrics) |
||||||
|
trainer.save_metrics("train", metrics) |
||||||
|
trainer.save_state() |
||||||
|
|
||||||
|
# Evaluation |
||||||
|
results = {} |
||||||
|
if training_args.do_eval: |
||||||
|
logger.info("*** Evaluate ***") |
||||||
|
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95) |
||||||
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) |
||||||
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) |
||||||
|
|
||||||
|
trainer.log_metrics("eval", metrics) |
||||||
|
trainer.save_metrics("eval", metrics) |
||||||
|
|
||||||
|
if training_args.do_predict: |
||||||
|
logger.info("*** Predict ***") |
||||||
|
|
||||||
|
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95) |
||||||
|
metrics = predict_results.metrics |
||||||
|
max_predict_samples = ( |
||||||
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) |
||||||
|
) |
||||||
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) |
||||||
|
|
||||||
|
trainer.log_metrics("predict", metrics) |
||||||
|
trainer.save_metrics("predict", metrics) |
||||||
|
|
||||||
|
if trainer.is_world_process_zero(): |
||||||
|
if training_args.predict_with_generate: |
||||||
|
predictions = tokenizer.batch_decode( |
||||||
|
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True |
||||||
|
) |
||||||
|
predictions = [pred.strip() for pred in predictions] |
||||||
|
labels = tokenizer.batch_decode( |
||||||
|
predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True |
||||||
|
) |
||||||
|
labels = [label.strip() for label in labels] |
||||||
|
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") |
||||||
|
with open(output_prediction_file, "w", encoding="utf-8") as writer: |
||||||
|
for p, l in zip(predictions, labels): |
||||||
|
res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) |
||||||
|
writer.write(f"{res}\n") |
||||||
|
return results |
||||||
|
|
||||||
|
|
||||||
|
def _mp_fn(index): |
||||||
|
# For xla_spawn (TPUs) |
||||||
|
main() |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
main() |
@ -0,0 +1,26 @@ |
|||||||
|
PRE_SEQ_LEN=128 |
||||||
|
LR=2e-2 |
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python3 main.py \ |
||||||
|
--do_train \ |
||||||
|
--train_file AdvertiseGen/train.json \ |
||||||
|
--validation_file AdvertiseGen/dev.json \ |
||||||
|
--prompt_column content \ |
||||||
|
--response_column summary \ |
||||||
|
--overwrite_cache \ |
||||||
|
--model_name_or_path THUDM/chatglm-6b \ |
||||||
|
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ |
||||||
|
--overwrite_output_dir \ |
||||||
|
--max_source_length 64 \ |
||||||
|
--max_target_length 64 \ |
||||||
|
--per_device_train_batch_size 1 \ |
||||||
|
--per_device_eval_batch_size 1 \ |
||||||
|
--gradient_accumulation_steps 16 \ |
||||||
|
--predict_with_generate \ |
||||||
|
--max_steps 3000 \ |
||||||
|
--logging_steps 10 \ |
||||||
|
--save_steps 1000 \ |
||||||
|
--learning_rate $LR \ |
||||||
|
--pre_seq_len $PRE_SEQ_LEN \ |
||||||
|
--quantization_bit 4 |
||||||
|
|
@ -0,0 +1,27 @@ |
|||||||
|
PRE_SEQ_LEN=8 |
||||||
|
LR=1e-2 |
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python3 main.py \ |
||||||
|
--do_train \ |
||||||
|
--train_file $CHAT_TRAIN_DATA \ |
||||||
|
--validation_file $CHAT_VAL_DATA \ |
||||||
|
--prompt_column prompt \ |
||||||
|
--response_column response \ |
||||||
|
--history_column history \ |
||||||
|
--overwrite_cache \ |
||||||
|
--model_name_or_path THUDM/chatglm-6b \ |
||||||
|
--output_dir $CHECKPOINT_NAME \ |
||||||
|
--overwrite_output_dir \ |
||||||
|
--max_source_length 256 \ |
||||||
|
--max_target_length 256 \ |
||||||
|
--per_device_train_batch_size 1 \ |
||||||
|
--per_device_eval_batch_size 1 \ |
||||||
|
--gradient_accumulation_steps 16 \ |
||||||
|
--predict_with_generate \ |
||||||
|
--max_steps 3000 \ |
||||||
|
--logging_steps 10 \ |
||||||
|
--save_steps 1000 \ |
||||||
|
--learning_rate $LR \ |
||||||
|
--pre_seq_len $PRE_SEQ_LEN \ |
||||||
|
--quantization_bit 4 |
||||||
|
|
@ -0,0 +1,247 @@ |
|||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union |
||||||
|
|
||||||
|
import torch |
||||||
|
from torch import nn |
||||||
|
from torch.utils.data import Dataset |
||||||
|
|
||||||
|
from transformers.deepspeed import is_deepspeed_zero3_enabled |
||||||
|
from trainer import Trainer |
||||||
|
from transformers.trainer_utils import PredictionOutput |
||||||
|
from transformers.utils import logging |
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqTrainer(Trainer): |
||||||
|
def evaluate( |
||||||
|
self, |
||||||
|
eval_dataset: Optional[Dataset] = None, |
||||||
|
ignore_keys: Optional[List[str]] = None, |
||||||
|
metric_key_prefix: str = "eval", |
||||||
|
**gen_kwargs |
||||||
|
) -> Dict[str, float]: |
||||||
|
""" |
||||||
|
Run evaluation and returns metrics. |
||||||
|
|
||||||
|
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent |
||||||
|
(pass it to the init `compute_metrics` argument). |
||||||
|
|
||||||
|
You can also subclass and override this method to inject custom behavior. |
||||||
|
|
||||||
|
Args: |
||||||
|
eval_dataset (`Dataset`, *optional*): |
||||||
|
Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns |
||||||
|
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` |
||||||
|
method. |
||||||
|
ignore_keys (`List[str]`, *optional*): |
||||||
|
A list of keys in the output of your model (if it is a dictionary) that should be ignored when |
||||||
|
gathering predictions. |
||||||
|
metric_key_prefix (`str`, *optional*, defaults to `"eval"`): |
||||||
|
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named |
||||||
|
"eval_bleu" if the prefix is `"eval"` (default) |
||||||
|
max_length (`int`, *optional*): |
||||||
|
The maximum target length to use when predicting with the generate method. |
||||||
|
num_beams (`int`, *optional*): |
||||||
|
Number of beams for beam search that will be used when predicting with the generate method. 1 means no |
||||||
|
beam search. |
||||||
|
gen_kwargs: |
||||||
|
Additional `generate` specific kwargs. |
||||||
|
|
||||||
|
Returns: |
||||||
|
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The |
||||||
|
dictionary also contains the epoch number which comes from the training state. |
||||||
|
""" |
||||||
|
|
||||||
|
gen_kwargs = gen_kwargs.copy() |
||||||
|
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: |
||||||
|
gen_kwargs["max_length"] = self.args.generation_max_length |
||||||
|
gen_kwargs["num_beams"] = ( |
||||||
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams |
||||||
|
) |
||||||
|
self._gen_kwargs = gen_kwargs |
||||||
|
|
||||||
|
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) |
||||||
|
|
||||||
|
def predict( |
||||||
|
self, |
||||||
|
test_dataset: Dataset, |
||||||
|
ignore_keys: Optional[List[str]] = None, |
||||||
|
metric_key_prefix: str = "test", |
||||||
|
**gen_kwargs |
||||||
|
) -> PredictionOutput: |
||||||
|
""" |
||||||
|
Run prediction and returns predictions and potential metrics. |
||||||
|
|
||||||
|
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method |
||||||
|
will also return metrics, like in `evaluate()`. |
||||||
|
|
||||||
|
Args: |
||||||
|
test_dataset (`Dataset`): |
||||||
|
Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the |
||||||
|
`model.forward()` method are automatically removed. Has to implement the method `__len__` |
||||||
|
ignore_keys (`List[str]`, *optional*): |
||||||
|
A list of keys in the output of your model (if it is a dictionary) that should be ignored when |
||||||
|
gathering predictions. |
||||||
|
metric_key_prefix (`str`, *optional*, defaults to `"eval"`): |
||||||
|
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named |
||||||
|
"eval_bleu" if the prefix is `"eval"` (default) |
||||||
|
max_length (`int`, *optional*): |
||||||
|
The maximum target length to use when predicting with the generate method. |
||||||
|
num_beams (`int`, *optional*): |
||||||
|
Number of beams for beam search that will be used when predicting with the generate method. 1 means no |
||||||
|
beam search. |
||||||
|
gen_kwargs: |
||||||
|
Additional `generate` specific kwargs. |
||||||
|
|
||||||
|
<Tip> |
||||||
|
|
||||||
|
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic |
||||||
|
padding in a token classification task) the predictions will be padded (on the right) to allow for |
||||||
|
concatenation into one array. The padding index is -100. |
||||||
|
|
||||||
|
</Tip> |
||||||
|
|
||||||
|
Returns: *NamedTuple* A namedtuple with the following keys: |
||||||
|
|
||||||
|
- predictions (`np.ndarray`): The predictions on `test_dataset`. |
||||||
|
- label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). |
||||||
|
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained |
||||||
|
labels). |
||||||
|
""" |
||||||
|
|
||||||
|
gen_kwargs = gen_kwargs.copy() |
||||||
|
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: |
||||||
|
gen_kwargs["max_length"] = self.args.generation_max_length |
||||||
|
gen_kwargs["num_beams"] = ( |
||||||
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams |
||||||
|
) |
||||||
|
self._gen_kwargs = gen_kwargs |
||||||
|
|
||||||
|
|
||||||
|
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) |
||||||
|
|
||||||
|
def prediction_step( |
||||||
|
self, |
||||||
|
model: nn.Module, |
||||||
|
inputs: Dict[str, Union[torch.Tensor, Any]], |
||||||
|
prediction_loss_only: bool, |
||||||
|
ignore_keys: Optional[List[str]] = None, |
||||||
|
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: |
||||||
|
""" |
||||||
|
Perform an evaluation step on `model` using `inputs`. |
||||||
|
|
||||||
|
Subclass and override to inject custom behavior. |
||||||
|
|
||||||
|
Args: |
||||||
|
model (`nn.Module`): |
||||||
|
The model to evaluate. |
||||||
|
inputs (`Dict[str, Union[torch.Tensor, Any]]`): |
||||||
|
The inputs and targets of the model. |
||||||
|
|
||||||
|
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the |
||||||
|
argument `labels`. Check your model's documentation for all accepted arguments. |
||||||
|
prediction_loss_only (`bool`): |
||||||
|
Whether or not to return the loss only. |
||||||
|
|
||||||
|
Return: |
||||||
|
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and |
||||||
|
labels (each being optional). |
||||||
|
""" |
||||||
|
|
||||||
|
if not self.args.predict_with_generate or prediction_loss_only: |
||||||
|
return super().prediction_step( |
||||||
|
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys |
||||||
|
) |
||||||
|
|
||||||
|
has_labels = "labels" in inputs |
||||||
|
inputs = self._prepare_inputs(inputs) |
||||||
|
|
||||||
|
# XXX: adapt synced_gpus for fairscale as well |
||||||
|
gen_kwargs = self._gen_kwargs.copy() |
||||||
|
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: |
||||||
|
gen_kwargs["max_length"] = self.model.config.max_length |
||||||
|
gen_kwargs["num_beams"] = ( |
||||||
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams |
||||||
|
) |
||||||
|
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False |
||||||
|
gen_kwargs["synced_gpus"] = ( |
||||||
|
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus |
||||||
|
) |
||||||
|
|
||||||
|
if "attention_mask" in inputs: |
||||||
|
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) |
||||||
|
if "position_ids" in inputs: |
||||||
|
gen_kwargs["position_ids"] = inputs.get("position_ids", None) |
||||||
|
if "global_attention_mask" in inputs: |
||||||
|
gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) |
||||||
|
|
||||||
|
# prepare generation inputs |
||||||
|
# some encoder-decoder models can have varying encoder's and thus |
||||||
|
# varying model input names |
||||||
|
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: |
||||||
|
generation_inputs = inputs[self.model.encoder.main_input_name] |
||||||
|
else: |
||||||
|
generation_inputs = inputs[self.model.main_input_name] |
||||||
|
|
||||||
|
gen_kwargs["input_ids"] = generation_inputs |
||||||
|
generated_tokens = self.model.generate(**gen_kwargs) |
||||||
|
generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] |
||||||
|
|
||||||
|
# in case the batch is shorter than max length, the output should be padded |
||||||
|
if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: |
||||||
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) |
||||||
|
elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( |
||||||
|
gen_kwargs["max_new_tokens"] + 1 |
||||||
|
): |
||||||
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) |
||||||
|
|
||||||
|
loss = None |
||||||
|
|
||||||
|
if self.args.prediction_loss_only: |
||||||
|
return (loss, None, None) |
||||||
|
|
||||||
|
if has_labels: |
||||||
|
labels = inputs["labels"] |
||||||
|
if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: |
||||||
|
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) |
||||||
|
elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( |
||||||
|
gen_kwargs["max_new_tokens"] + 1 |
||||||
|
): |
||||||
|
labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) |
||||||
|
else: |
||||||
|
labels = None |
||||||
|
|
||||||
|
return (loss, generated_tokens, labels) |
||||||
|
|
||||||
|
def _pad_tensors_to_max_len(self, tensor, max_length): |
||||||
|
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): |
||||||
|
# If PAD token is not defined at least EOS token has to be defined |
||||||
|
pad_token_id = ( |
||||||
|
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id |
||||||
|
) |
||||||
|
else: |
||||||
|
if self.model.config.pad_token_id is not None: |
||||||
|
pad_token_id = self.model.config.pad_token_id |
||||||
|
else: |
||||||
|
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") |
||||||
|
|
||||||
|
padded_tensor = pad_token_id * torch.ones( |
||||||
|
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device |
||||||
|
) |
||||||
|
padded_tensor[:, : tensor.shape[-1]] = tensor |
||||||
|
return padded_tensor |
@ -1,7 +1,8 @@ |
|||||||
protobuf>=3.19.5,<3.20.1 |
protobuf |
||||||
transformers==4.26.1 |
transformers==4.27.1 |
||||||
icetk |
|
||||||
cpm_kernels |
cpm_kernels |
||||||
torch>=1.10 |
torch>=1.10 |
||||||
gradio |
gradio |
||||||
|
mdtex2html |
||||||
|
sentencepiece |
||||||
accelerate |
accelerate |
@ -0,0 +1,7 @@ |
|||||||
|
<div align="center"> |
||||||
|
<img src=wechat.jpg width="60%"/> |
||||||
|
|
||||||
|
<p> 扫码关注公众号,加入「ChatGLM交流群」 </p> |
||||||
|
<p> Scan the QR code to follow the official account and join the "ChatGLM Discussion Group" </p> |
||||||
|
</div> |
||||||
|
|
After Width: | Height: | Size: 151 KiB |
@ -1,44 +1,100 @@ |
|||||||
import gradio as gr |
import gradio as gr |
||||||
|
import mdtex2html |
||||||
|
|
||||||
from utils import load_model_and_tokenizer |
from utils import load_model_and_tokenizer |
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) |
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) |
||||||
|
|
||||||
MAX_TURNS = 20 |
"""Override Chatbot.postprocess""" |
||||||
MAX_BOXES = MAX_TURNS * 2 |
|
||||||
|
|
||||||
|
|
||||||
def predict(input, max_length, top_p, temperature, history=None): |
def postprocess(self, y): |
||||||
if history is None: |
if y is None: |
||||||
history = [] |
return [] |
||||||
|
for i, (message, response) in enumerate(y): |
||||||
|
y[i] = ( |
||||||
|
None if message is None else mdtex2html.convert((message)), |
||||||
|
None if response is None else mdtex2html.convert(response), |
||||||
|
) |
||||||
|
return y |
||||||
|
|
||||||
|
|
||||||
|
gr.Chatbot.postprocess = postprocess |
||||||
|
|
||||||
|
|
||||||
|
def parse_text(text): |
||||||
|
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" |
||||||
|
lines = text.split("\n") |
||||||
|
lines = [line for line in lines if line != ""] |
||||||
|
count = 0 |
||||||
|
for i, line in enumerate(lines): |
||||||
|
if "```" in line: |
||||||
|
count += 1 |
||||||
|
items = line.split('`') |
||||||
|
if count % 2 == 1: |
||||||
|
lines[i] = f'<pre><code class="language-{items[-1]}">' |
||||||
|
else: |
||||||
|
lines[i] = f'<br></code></pre>' |
||||||
|
else: |
||||||
|
if i > 0: |
||||||
|
if count % 2 == 1: |
||||||
|
line = line.replace("`", "\`") |
||||||
|
line = line.replace("<", "<") |
||||||
|
line = line.replace(">", ">") |
||||||
|
line = line.replace(" ", " ") |
||||||
|
line = line.replace("*", "*") |
||||||
|
line = line.replace("_", "_") |
||||||
|
line = line.replace("-", "-") |
||||||
|
line = line.replace(".", ".") |
||||||
|
line = line.replace("!", "!") |
||||||
|
line = line.replace("(", "(") |
||||||
|
line = line.replace(")", ")") |
||||||
|
line = line.replace("$", "$") |
||||||
|
lines[i] = "<br>"+line |
||||||
|
text = "".join(lines) |
||||||
|
return text |
||||||
|
|
||||||
|
|
||||||
|
def predict(input, chatbot, max_length, top_p, temperature, history): |
||||||
|
chatbot.append((parse_text(input), "")) |
||||||
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, |
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, |
||||||
temperature=temperature): |
temperature=temperature): |
||||||
updates = [] |
chatbot[-1] = (parse_text(input), parse_text(response)) |
||||||
for query, response in history: |
|
||||||
updates.append(gr.update(visible=True, value="用户:" + query)) |
yield chatbot, history |
||||||
updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response)) |
|
||||||
if len(updates) < MAX_BOXES: |
|
||||||
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates)) |
def reset_user_input(): |
||||||
yield [history] + updates |
return gr.update(value='') |
||||||
|
|
||||||
|
|
||||||
|
def reset_state(): |
||||||
|
return [], [] |
||||||
|
|
||||||
|
|
||||||
with gr.Blocks() as demo: |
with gr.Blocks() as demo: |
||||||
state = gr.State([]) |
gr.HTML("""<h1 align="center">ChatGLM</h1>""") |
||||||
text_boxes = [] |
|
||||||
for i in range(MAX_BOXES): |
|
||||||
if i % 2 == 0: |
|
||||||
text_boxes.append(gr.Markdown(visible=False, label="提问:")) |
|
||||||
else: |
|
||||||
text_boxes.append(gr.Markdown(visible=False, label="回复:")) |
|
||||||
|
|
||||||
|
chatbot = gr.Chatbot() |
||||||
with gr.Row(): |
with gr.Row(): |
||||||
with gr.Column(scale=4): |
with gr.Column(scale=4): |
||||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style( |
with gr.Column(scale=12): |
||||||
container=False) |
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( |
||||||
|
container=False) |
||||||
|
with gr.Column(min_width=32, scale=1): |
||||||
|
submitBtn = gr.Button("Submit", variant="primary") |
||||||
with gr.Column(scale=1): |
with gr.Column(scale=1): |
||||||
|
emptyBtn = gr.Button("Clear History") |
||||||
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) |
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) |
||||||
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) |
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) |
||||||
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) |
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) |
||||||
button = gr.Button("Generate") |
|
||||||
button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) |
history = gr.State([]) |
||||||
|
|
||||||
|
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], |
||||||
|
show_progress=True) |
||||||
|
submitBtn.click(reset_user_input, [], [user_input]) |
||||||
|
|
||||||
|
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) |
||||||
|
|
||||||
demo.queue().launch(share=False, inbrowser=True) |
demo.queue().launch(share=False, inbrowser=True) |
||||||
|
@ -0,0 +1,45 @@ |
|||||||
|
from transformers import AutoModel, AutoTokenizer |
||||||
|
import gradio as gr |
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) |
||||||
|
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() |
||||||
|
model = model.eval() |
||||||
|
|
||||||
|
MAX_TURNS = 20 |
||||||
|
MAX_BOXES = MAX_TURNS * 2 |
||||||
|
|
||||||
|
|
||||||
|
def predict(input, max_length, top_p, temperature, history=None): |
||||||
|
if history is None: |
||||||
|
history = [] |
||||||
|
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, |
||||||
|
temperature=temperature): |
||||||
|
updates = [] |
||||||
|
for query, response in history: |
||||||
|
updates.append(gr.update(visible=True, value="用户:" + query)) |
||||||
|
updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response)) |
||||||
|
if len(updates) < MAX_BOXES: |
||||||
|
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates)) |
||||||
|
yield [history] + updates |
||||||
|
|
||||||
|
|
||||||
|
with gr.Blocks() as demo: |
||||||
|
state = gr.State([]) |
||||||
|
text_boxes = [] |
||||||
|
for i in range(MAX_BOXES): |
||||||
|
if i % 2 == 0: |
||||||
|
text_boxes.append(gr.Markdown(visible=False, label="提问:")) |
||||||
|
else: |
||||||
|
text_boxes.append(gr.Markdown(visible=False, label="回复:")) |
||||||
|
|
||||||
|
with gr.Row(): |
||||||
|
with gr.Column(scale=4): |
||||||
|
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style( |
||||||
|
container=False) |
||||||
|
with gr.Column(scale=1): |
||||||
|
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) |
||||||
|
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) |
||||||
|
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) |
||||||
|
button = gr.Button("Generate") |
||||||
|
button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) |
||||||
|
demo.queue().launch(share=False, inbrowser=True) |
Loading…
Reference in new issue