mirror of https://github.com/THUDM/ChatGLM-6B
Merge branch 'dev' into dev_multi_gpu
# Conflicts: # README.md # api.py # cli_demo.py # requirements.txtdev_multi_gpu
commit
90f2e47f54
|
@ -0,0 +1,133 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
history/
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Mac system file
|
||||
model/
|
|
@ -0,0 +1,18 @@
|
|||
# 友情链接
|
||||
|
||||
以下是部分基于本仓库开发的开源项目:
|
||||
* [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。
|
||||
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
|
||||
* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf)
|
||||
* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于本地知识的 ChatGLM 应用,基于LangChain
|
||||
* [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。
|
||||
* [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
|
||||
* [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署
|
||||
* [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。
|
||||
* [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM):基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题
|
||||
* [ChatGLM-web](https://github.com/NCZkevin/chatglm-web):基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能)
|
||||
* [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM
|
||||
|
||||
以下是部分针对本项目的教程/文档:
|
||||
* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
|
||||
* [ChatGLM-6B 的部署与微调教程 @ModelWhale平台](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
|
99
README.md
99
README.md
|
@ -1,39 +1,42 @@
|
|||
# ChatGLM-6B
|
||||
|
||||
<p align="center">
|
||||
🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br>
|
||||
</p>
|
||||
<p align="center">
|
||||
👋 加入我们的 <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1t4a8evfn-vduo2hhNcYqBUnZ71IXiqQ" target="_blank">Slack</a> 和 <a href="resources/WECHAT.md" target="_blank">WeChat</a>
|
||||
</p>
|
||||
|
||||
## 介绍
|
||||
|
||||
ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。
|
||||
ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。更多信息请参考我们的[博客](https://chatglm.cn/blog)。
|
||||
ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答,更多信息请参考我们的[博客](https://chatglm.cn/blog)。
|
||||
|
||||
不过,由于 ChatGLM-6B 的规模较小,目前已知其具有相当多的[**局限性**](#局限性),如事实性/数学逻辑错误,可能生成有害/有偏见内容,较弱的上下文能力,自我认知混乱,以及对英文指示生成与中文指示完全矛盾的内容。请大家在使用前了解这些问题,以免产生误解。更大的基于1300亿参数[GLM-130B](https://github.com/THUDM/GLM-130B)的ChatGLM正在内测开发中。
|
||||
为了方便下游开发者针对自己的应用场景定制模型,我们同时实现了基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调方法 [(使用指南)](ptuning/README.md) ,INT4 量化级别下最低只需 7GB 显存即可启动微调。
|
||||
|
||||
不过,由于 ChatGLM-6B 的规模较小,目前已知其具有相当多的[**局限性**](#局限性),如事实性/数学逻辑错误,可能生成有害/有偏见内容,较弱的上下文能力,自我认知混乱,以及对英文指示生成与中文指示完全矛盾的内容。请大家在使用前了解这些问题,以免产生误解。更大的基于 1300 亿参数 [GLM-130B](https://github.com/THUDM/GLM-130B) 的 ChatGLM 正在内测开发中。
|
||||
|
||||
*Read this in [English](README_en.md).*
|
||||
|
||||
## 更新信息
|
||||
**[2023/03/23]** 增加API部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。增加Embedding量化模型[ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)。增加对基于Apple Silicon的Mac上GPU加速的支持。
|
||||
|
||||
**[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4)
|
||||
|
||||
## 友情链接
|
||||
以下是部分基于本仓库开发的开源项目:
|
||||
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小动态分配计算任务给 GPU 和 CPU
|
||||
* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调
|
||||
部分基于本仓库开发的开源项目参见 [PROJECT.md](PROJECT.md)
|
||||
|
||||
如果你有其他好的项目的话,欢迎参照上述格式添加到README中并提出 [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork).
|
||||
如果你有其他好的项目/教程的话,欢迎参照上述格式添加到 README 中并提出 [Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)。
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 硬件需求
|
||||
|
||||
| **量化等级** | **最低 GPU 显存** |
|
||||
| -------------- | ----------------- |
|
||||
| FP16(无量化) | 13 GB |
|
||||
| INT8 | 10 GB |
|
||||
| INT4 | 6 GB |
|
||||
|
||||
| **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) |
|
||||
| -------------- | ------------------------- | --------------------------------- |
|
||||
| FP16(无量化) | 13 GB | 14 GB |
|
||||
| INT8 | 8 GB | 9 GB |
|
||||
| INT4 | 6 GB | 7 GB |
|
||||
### 环境安装
|
||||
|
||||
使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.26.1`,但理论上不低于 `4.23.1` 即可。
|
||||
使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。
|
||||
|
||||
此外,如果需要在 cpu 上运行量化后的模型,还需要安装 `gcc` 与 `openmp`。多数 Linux 发行版默认已安装。对于 Windows ,可在安装 [TDM-GCC](https://jmeubank.github.io/tdm-gcc/) 时勾选 `openmp`。 Windows 测试环境 `gcc` 版本为 `TDM-GCC 10.3.0`, Linux 为 `gcc 11.3.0`。
|
||||
|
||||
### 代码调用
|
||||
|
||||
|
@ -60,9 +63,23 @@ ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进
|
|||
|
||||
如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
|
||||
```
|
||||
完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b) 上查看。如果你从 Hugging Face Hub 上下载checkpoint的速度较慢,也可以从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载。
|
||||
### 从本地加载模型
|
||||
以上代码会由 `transformers` 自动下载模型实现和参数。完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b)。如果你的网络环境较差,下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地,然后从本地加载。
|
||||
|
||||
### Demo
|
||||
从 Hugging Face Hub 下载模型需要先[安装Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行
|
||||
```Shell
|
||||
git clone https://huggingface.co/THUDM/chatglm-6b
|
||||
```
|
||||
|
||||
如果你从 Hugging Face Hub 上下载 checkpoint 的速度较慢,可以只下载模型实现
|
||||
```Shell
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b
|
||||
```
|
||||
然后从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载模型参数文件,并将下载的文件替换到本地的 `chatglm-6b` 目录下。
|
||||
|
||||
将模型下载到本地之后,将以上代码中的 `THUDM/chatglm-6b` 替换为你本地的 `chatglm-6b` 文件夹的路径,即可从本地加载模型。
|
||||
|
||||
## Demo & API
|
||||
|
||||
我们提供了一个基于 [Gradio](https://gradio.app) 的网页版 Demo 和一个命令行 Demo。使用时首先需要下载本仓库:
|
||||
|
||||
|
@ -95,14 +112,14 @@ python web_demo.py
|
|||
python cli_demo.py
|
||||
```
|
||||
|
||||
程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入`clear`可以清空对话历史,输入`stop`终止程序。
|
||||
程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入 `clear` 可以清空对话历史,输入 `stop` 终止程序。
|
||||
|
||||
### API部署
|
||||
首先需要安装额外的依赖`pip install fastapi uvicorn`,然后运行仓库中的[api.py](api.py):
|
||||
首先需要安装额外的依赖 `pip install fastapi uvicorn`,然后运行仓库中的 [api.py](api.py):
|
||||
```shell
|
||||
python api.py
|
||||
```
|
||||
默认部署在本地的8000端口,通过POST方法进行调用
|
||||
默认部署在本地的 8000 端口,通过 POST 方法进行调用
|
||||
```shell
|
||||
curl -X POST "http://127.0.0.1:8000" \
|
||||
-H 'Content-Type: application/json' \
|
||||
|
@ -124,44 +141,35 @@ curl -X POST "http://127.0.0.1:8000" \
|
|||
|
||||
```python
|
||||
# 按需修改,目前只支持 4/8 bit 量化
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(4).half().cuda()
|
||||
```
|
||||
|
||||
进行 2 至 3 轮对话后,8-bit 量化下 GPU 显存占用约为 10GB,4-bit 量化下仅需 6GB 占用。随着对话轮数的增多,对应消耗显存也随之增长,由于采用了相对位置编码,理论上 ChatGLM-6B 支持无限长的 context-length,但总长度超过 2048(训练长度)后性能会逐渐下降。
|
||||
|
||||
模型量化会带来一定的性能损失,经过测试,ChatGLM-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。使用 [GPT-Q](https://arxiv.org/abs/2210.17323) 等量化方案可以进一步压缩量化精度/提升相同量化精度下的模型性能,欢迎大家提出对应的 Pull Request。
|
||||
|
||||
**[2023/03/19]** 量化过程需要在内存中首先加载 FP16 格式的模型,消耗大概 13GB 的内存。如果你的内存不足的话,可以直接加载量化后的模型,仅需大概 5.2GB 的内存:
|
||||
量化过程需要在内存中首先加载 FP16 格式的模型,消耗大概 13GB 的内存。如果你的内存不足的话,可以直接加载量化后的模型,仅需大概 5.2GB 的内存:
|
||||
```python
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
|
||||
```
|
||||
|
||||
**[2023/03/24]** 我们进一步提供了对Embedding量化后的模型,模型参数仅占用4.3 GB显存:
|
||||
```python
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4-qe", trust_remote_code=True).half().cuda()
|
||||
```
|
||||
|
||||
|
||||
|
||||
### CPU 部署
|
||||
如果你没有 GPU 硬件的话,也可以在 CPU 上进行推理,但是推理速度会更慢。使用方法如下(需要大概 32GB 内存)
|
||||
```python
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
|
||||
```
|
||||
|
||||
**[2023/03/19]** 如果你的内存不足,可以直接加载量化后的模型:
|
||||
如果你的内存不足,可以直接加载量化后的模型:
|
||||
```python
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
|
||||
```
|
||||
|
||||
如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) 的话请参考这个[Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041).
|
||||
如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) ,请[从本地加载模型](README.md#从本地加载模型)
|
||||
|
||||
### Mac 上的 GPU 加速
|
||||
对于搭载了Apple Silicon的Mac(以及MacBook),可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。首先需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly。然后将模型仓库 clone 到本地
|
||||
```shell
|
||||
git clone https://huggingface.co/THUDM/chatglm-6b
|
||||
```
|
||||
将代码中的模型加载改为从本地加载,并使用 mps 后端
|
||||
对于搭载了Apple Silicon的Mac(以及MacBook),可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly。
|
||||
|
||||
目前在 MacOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载,并使用 mps 后端
|
||||
```python
|
||||
model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
|
||||
```
|
||||
|
@ -178,6 +186,19 @@ from utils import load_model_and_tokenizer
|
|||
model, tokenizer = load_model_and_tokenizer("your local path", num_gpus=2)
|
||||
```
|
||||
即可将模型部署到多卡上进行推理。
|
||||
|
||||
## 高效参数微调
|
||||
基于 [P-tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调。具体使用方法详见 [ptuning/README.md](ptuning/README.md)。
|
||||
|
||||
## 更新信息
|
||||
**[2023/04/06]** 优化web demo的界面(感谢 [@tuteng0915](https://github.com/tuteng0915))。移除embedding中的image token以减小显存占用(需要更新模型文件`pytorch_model-00001-of-00008.bin`和`pytorch_model-00008-of-00008.bin`,感谢 [@silverriver](https://github.com/silverriver) 提出的想法)。去掉了对 `icetk` 的依赖(需要更新模型文件`ice_text.model`)。
|
||||
|
||||
**[2023/03/31]** 增加基于 [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调实现,INT4 量化级别下最低只需 7GB 显存即可进行模型微调。详见[高效参数微调方法](ptuning/README.md)。
|
||||
|
||||
**[2023/03/23]** 增加 API 部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。增加 Embedding 量化模型 [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)。增加配备 Apple Silicon 芯片的 Mac 上 GPU 加速的支持。
|
||||
|
||||
**[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4)
|
||||
|
||||
## ChatGLM-6B 示例
|
||||
|
||||
以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现!
|
||||
|
|
41
README_en.md
41
README_en.md
|
@ -1,5 +1,13 @@
|
|||
# ChatGLM-6B
|
||||
|
||||
|
||||
<p align="center">
|
||||
🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2210.02414" target="_blank">[GLM-130B@ICLR 23]</a> <a href="https://github.com/THUDM/GLM-130B" target="_blank">[GitHub]</a> <br>
|
||||
</p>
|
||||
<p align="center">
|
||||
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1t4a8evfn-vduo2hhNcYqBUnZ71IXiqQ" target="_blank">Slack</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a>
|
||||
</p>
|
||||
|
||||
## Introduction
|
||||
|
||||
ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level).
|
||||
|
@ -9,13 +17,15 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dial
|
|||
Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Huggingface Spaces.
|
||||
|
||||
## Update
|
||||
**[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
|
||||
**[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). The minimum INT4 quantization level only needs 7GB GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details.
|
||||
|
||||
**[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon.
|
||||
|
||||
**[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4).
|
||||
|
||||
## Projects
|
||||
The following are some open source projects developed based on this repository:
|
||||
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An [MNN](https://github.com/alibaba/MNN)-based implementation of ChatGLM-6B C++ inference, which supports dynamic allocation of computing tasks to GPU and CPU according to the size of GPU memory
|
||||
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An [MNN](https://github.com/alibaba/MNN)-based implementation of ChatGLM-6B C++ inference, which supports automatic allocation of computing tasks to GPU and CPU according to the size of GPU memory
|
||||
* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): Fine-tuning ChatGLM-6B based on LoRA
|
||||
|
||||
If you have other good projects, please refer to the above format to add to README and propose [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork).
|
||||
|
@ -32,7 +42,9 @@ If you have other good projects, please refer to the above format to add to READ
|
|||
|
||||
### Environment Setup
|
||||
|
||||
Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.26.1`, but theoretically any version no lower than `4.23.1` is acceptable.
|
||||
Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.27.1`, but theoretically any version no lower than `4.23.1` is acceptable.
|
||||
|
||||
In addition, if you need to run the quantified model on the CPU, you also need to install `gcc` and `openmp`. Most Linux distributions are installed by default. For Windows, you can check `openmp` when installing [TDM-GCC](https://jmeubank.github.io/tdm-gcc/). On Windows testing environment, the `gcc` version is `TDM-GCC 10.3.0`, and on Linux is `gcc 11.3.0`.
|
||||
|
||||
### Usage
|
||||
|
||||
|
@ -136,11 +148,6 @@ Model quantization brings a certain performance decline. After testing, ChatGLM-
|
|||
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
|
||||
```
|
||||
|
||||
**[2023/03/24]** We further provide an embedding-quantized model whose model parameters only cost 4.3GB GPU memory
|
||||
```python
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4-qe", trust_remote_code=True).half().cuda()
|
||||
```
|
||||
|
||||
### CPU Deployment
|
||||
|
||||
If your computer is not equipped with GPU, you can also conduct inference on CPU, but the inference speed is slow (and taking about 32GB of memory):
|
||||
|
@ -154,7 +161,23 @@ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).fl
|
|||
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
|
||||
```
|
||||
|
||||
**For Mac users**: if your encounter the error `RuntimeError: Unknown platform: darwin`, please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041).
|
||||
If your encounter the error `Could not find module 'nvcuda.dll'` or `RuntimeError: Unknown platform: darwin`(MacOS), please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041).
|
||||
|
||||
### GPU Inference on Mac
|
||||
For Macs (and MacBooks) with Apple Silicon, it is possible to use the MPS backend to run ChatGLM-6B on the GPU. First, you need to refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly. Then clone the model repository locally (you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage))
|
||||
```shell
|
||||
git lfs install
|
||||
git clone https://huggingface.co/THUDM/chatglm-6b
|
||||
```
|
||||
Change the code to load the model from your local path, and use the mps backend:
|
||||
```python
|
||||
model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
|
||||
```
|
||||
Then you can use GPU-accelerated model inference on Mac.
|
||||
|
||||
## Parameter-efficient Tuning
|
||||
Parameter-efficient tuning based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2). See [ptuning/README.md](ptuning/README.md) for details on how to use it.
|
||||
|
||||
|
||||
### Multi-GPU Deployment
|
||||
|
||||
|
|
37
api.py
37
api.py
|
@ -1,10 +1,19 @@
|
|||
import datetime
|
||||
import json
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import uvicorn, json, datetime
|
||||
import torch
|
||||
|
||||
DEVICE = "cuda"
|
||||
DEVICE_ID = "0"
|
||||
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
|
||||
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(CUDA_DEVICE):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
from utils import load_model_and_tokenizer
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
@ -17,7 +26,15 @@ async def create_item(request: Request):
|
|||
json_post_list = json.loads(json_post)
|
||||
prompt = json_post_list.get('prompt')
|
||||
history = json_post_list.get('history')
|
||||
response, history = model.chat(tokenizer, prompt, history=history)
|
||||
max_length = json_post_list.get('max_length')
|
||||
top_p = json_post_list.get('top_p')
|
||||
temperature = json_post_list.get('temperature')
|
||||
response, history = model.chat(tokenizer,
|
||||
prompt,
|
||||
history=history,
|
||||
max_length=max_length if max_length else 2048,
|
||||
top_p=top_p if top_p else 0.7,
|
||||
temperature=temperature if temperature else 0.95)
|
||||
now = datetime.datetime.now()
|
||||
time = now.strftime("%Y-%m-%d %H:%M:%S")
|
||||
answer = {
|
||||
|
@ -28,10 +45,12 @@ async def create_item(request: Request):
|
|||
}
|
||||
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
|
||||
print(log)
|
||||
torch_gc()
|
||||
return answer
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1)
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
||||
model.eval()
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||
|
|
32
cli_demo.py
32
cli_demo.py
|
@ -1,12 +1,15 @@
|
|||
import os
|
||||
import platform
|
||||
import signal
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
from utils import load_model_and_tokenizer
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
||||
model = model.eval()
|
||||
|
||||
os_name = platform.system()
|
||||
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
||||
stop_stream = False
|
||||
|
||||
|
||||
def build_prompt(history):
|
||||
|
@ -17,24 +20,35 @@ def build_prompt(history):
|
|||
return prompt
|
||||
|
||||
|
||||
def signal_handler(signal, frame):
|
||||
global stop_stream
|
||||
stop_stream = True
|
||||
|
||||
|
||||
def main():
|
||||
history = []
|
||||
global stop_stream
|
||||
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
||||
while True:
|
||||
query = input("\n用户:")
|
||||
if query == "stop":
|
||||
if query.strip() == "stop":
|
||||
break
|
||||
if query == "clear":
|
||||
if query.strip() == "clear":
|
||||
history = []
|
||||
os.system(clear_command)
|
||||
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
||||
continue
|
||||
count = 0
|
||||
for response, history in model.stream_chat(tokenizer, query, history=history):
|
||||
count += 1
|
||||
if count % 8 == 0:
|
||||
os.system(clear_command)
|
||||
print(build_prompt(history), flush=True)
|
||||
if stop_stream:
|
||||
stop_stream = False
|
||||
break
|
||||
else:
|
||||
count += 1
|
||||
if count % 8 == 0:
|
||||
os.system(clear_command)
|
||||
print(build_prompt(history), flush=True)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
os.system(clear_command)
|
||||
print(build_prompt(history), flush=True)
|
||||
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
# ChatGLM-6B-PT
|
||||
本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。
|
||||
|
||||
下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。
|
||||
|
||||
*Read this in [English](README_en.md).*
|
||||
|
||||
## 软件依赖
|
||||
运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖
|
||||
```
|
||||
pip install rouge_chinese nltk jieba datasets
|
||||
```
|
||||
## 使用方法
|
||||
|
||||
### 下载数据集
|
||||
ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
|
||||
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
|
||||
}
|
||||
```
|
||||
|
||||
从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 ADGEN 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。
|
||||
|
||||
### 训练
|
||||
|
||||
#### P-tuning v2
|
||||
|
||||
运行以下指令进行训练:
|
||||
```shell
|
||||
bash train.sh
|
||||
```
|
||||
`train.sh` 中的 `PRE_SEQ_LEN` 和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 `quantization_bit` 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。
|
||||
|
||||
在默认配置 `quantization_bit=4`、`per_device_train_batch_size=1`、`gradient_accumulation_steps=16` 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 `per_device_train_batch_size` 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。
|
||||
|
||||
如果你想要[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B),可以将 `train.sh` 中的 `THUDM/chatglm-6b` 改为你本地的模型路径。
|
||||
|
||||
#### Finetune
|
||||
|
||||
如果需要进行全参数的 Finetune,需要安装 [Deepspeed](https://github.com/microsoft/DeepSpeed),然后运行以下指令:
|
||||
|
||||
```shell
|
||||
bash ds_train_finetune.sh
|
||||
```
|
||||
|
||||
### 推理
|
||||
|
||||
将 `evaluate.sh` 中的 `CHECKPOINT` 更改为训练时保存的 checkpoint 名称,运行以下指令进行模型推理和评测:
|
||||
```shell
|
||||
bash evaluate.sh
|
||||
```
|
||||
**[2023/04/10更新]** 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重,因此需要指定参数(已更新 `evaluate.sh`) :
|
||||
|
||||
```shell
|
||||
--model_name_or_path THUDM/chatglm-6b
|
||||
--ptuning_checkpoint $CHECKPOINT_PATH
|
||||
```
|
||||
|
||||
仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`:
|
||||
|
||||
```shell
|
||||
--model_name_or_path $CHECKPOINT_PATH
|
||||
```
|
||||
|
||||
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
|
||||
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。
|
||||
|
||||
### 例子
|
||||
#### 示例1
|
||||
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
|
||||
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
|
||||
* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
|
||||
* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
|
||||
|
||||
#### 示例2
|
||||
|
||||
* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
|
||||
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
|
||||
* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
|
||||
* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
|
||||
|
||||
### 评估结果
|
||||
|
||||
| | Finetune | P-tuning v2 | LoRA |
|
||||
| ------------- | ----------- | ----- | ------------- |
|
||||
| BLEU-4 | 8.01 | 8.10 | 7.62 |
|
||||
| Rouge-1 | 31.23 | 31.12 | 30.60 |
|
||||
| Rouge-2 | 7.36 | 7.11 | 6.96 |
|
||||
| Rouge-l | 25.08 | 24.97 | 24.80 |
|
||||
| Training Loss | 3.00 | 3.74 | 3.32 |
|
||||
|
||||
|
||||
|
||||
#### 实验设置
|
||||
|
||||
```
|
||||
max_source_length=64
|
||||
max_target_length=64
|
||||
max_steps=3000
|
||||
```
|
||||
|
||||
##### P-tuning v2
|
||||
|
||||
```
|
||||
pre_seq_len=128
|
||||
learning_rate=2e-2
|
||||
quantization_bit=4
|
||||
per_device_train_batch_size=16
|
||||
gradient_accumulation_steps=1
|
||||
```
|
||||
|
||||
##### Finetune
|
||||
|
||||
```
|
||||
learning_rate=1e-4
|
||||
fp16
|
||||
num_gpus=4
|
||||
per_device_train_batch_size=4
|
||||
gradient_accumulation_steps=1
|
||||
```
|
||||
|
||||
##### LoRA
|
||||
|
||||
实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
|
||||
|
||||
```
|
||||
learning_rate=5e-4
|
||||
per_device_train_batch_size=16
|
||||
gradient_accumulation_steps=1
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 模型部署
|
||||
首先载入Tokenizer:
|
||||
|
||||
```python
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
||||
|
||||
# 载入Tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
||||
```
|
||||
|
||||
1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数):
|
||||
|
||||
```python
|
||||
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
|
||||
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
|
||||
new_prefix_state_dict = {}
|
||||
for k, v in prefix_state_dict.items():
|
||||
if k.startswith("transformer.prefix_encoder."):
|
||||
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||
```
|
||||
注意你可能需要将 `pre_seq_len` 改成你训练时的实际值。如果你是[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)的话,需要将 `THUDM/chatglm-6b` 改成本地的模型路径(注意不是checkpoint路径)。
|
||||
|
||||
2. 如果需要加载的是旧 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 参数),或者进行的是全参数微调,则直接加载整个 Checkpoint:
|
||||
|
||||
```python
|
||||
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
|
||||
```
|
||||
|
||||
之后根据需求可以进行量化,也可以直接使用:
|
||||
|
||||
```python
|
||||
# Comment out the following line if you don't use quantization
|
||||
model = model.quantize(4)
|
||||
model = model.half().cuda()
|
||||
model.transformer.prefix_encoder.float()
|
||||
model = model.eval()
|
||||
|
||||
response, history = model.chat(tokenizer, "你好", history=[])
|
||||
```
|
||||
|
||||
## 使用自己的数据集
|
||||
修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 `max_source_length` 和 `max_target_length` 来匹配你自己的数据集中的最大输入输出长度。
|
||||
|
||||
## 对话数据集
|
||||
|
||||
如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "是的。上下水管都好的",
|
||||
"response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
|
||||
"history": [
|
||||
[
|
||||
"长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
|
||||
"用电脑能读数据流吗?水温多少"
|
||||
],
|
||||
[
|
||||
"95",
|
||||
"上下水管温差怎么样啊?空气是不是都排干净了呢?"
|
||||
]
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接,例如:
|
||||
|
||||
- Input
|
||||
|
||||
```
|
||||
[Round 0]
|
||||
问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线
|
||||
答:用电脑能读数据流吗?水温多少
|
||||
[Round 1]
|
||||
问:95
|
||||
答:上下水管温差怎么样啊?空气是不是都排干净了呢?
|
||||
[Round 2]
|
||||
问:是的。上下水管都好的
|
||||
答:
|
||||
```
|
||||
|
||||
- Label
|
||||
|
||||
```
|
||||
那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!
|
||||
```
|
||||
|
||||
要注意超过输入长度 `max_source_length` 的内容会被截。
|
||||
|
||||
可以参考以下指令:
|
||||
|
||||
```shell
|
||||
bash train_chat.sh
|
||||
```
|
||||
|
||||
## 引用
|
||||
|
||||
```
|
||||
@inproceedings{liu2022p,
|
||||
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
|
||||
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
|
||||
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
|
||||
pages={61--68},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
# ChatGLM-6B-PT
|
||||
This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run.
|
||||
|
||||
The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code.
|
||||
|
||||
## Software dependencies
|
||||
Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required
|
||||
```
|
||||
pip install rouge_chinese nltk jieba datasets
|
||||
```
|
||||
## Instructions
|
||||
|
||||
### Download the dataset
|
||||
The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content).
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
|
||||
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
|
||||
}
|
||||
```
|
||||
|
||||
From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) Download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory.
|
||||
|
||||
### Training
|
||||
Run the following commands for training:
|
||||
```shell
|
||||
bash train.sh
|
||||
```
|
||||
`PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision.
|
||||
|
||||
Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation.
|
||||
|
||||
### Inference
|
||||
|
||||
Change `CHECKPOINT` in `evaluate.sh` to the checkpoint name saved during training, and run the following commands for model inference and evaluation:
|
||||
```shell
|
||||
bash evaluate.sh
|
||||
```
|
||||
|
||||
The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in
|
||||
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.
|
||||
|
||||
### Example
|
||||
#### Example 1
|
||||
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
|
||||
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
|
||||
* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
|
||||
* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
|
||||
|
||||
#### Example 2
|
||||
|
||||
* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
|
||||
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
|
||||
* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
|
||||
* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
|
||||
|
||||
### evaluation result
|
||||
|
||||
| | P-tuning v2 | LoRA |
|
||||
| ------- | ----------- | ----- |
|
||||
| BLEU-4 | 7.71 | 6.13 |
|
||||
| Rouge-1 | 31.35 | 28.36 |
|
||||
| Rouge-2 | 7.19 | 4.38 |
|
||||
| Rouge-l | 25.17 | 17.54 |
|
||||
|
||||
#### Experiment Settings
|
||||
|
||||
```
|
||||
max_source_length=64
|
||||
max_target_length=64
|
||||
per_device_train_batch_size=1
|
||||
gradient_accumulation_steps=16
|
||||
max_steps=3000
|
||||
```
|
||||
|
||||
##### P-tuning v2
|
||||
|
||||
```
|
||||
pre_seq_len=128
|
||||
learning_rate=2e-2
|
||||
quantization_bit=4
|
||||
```
|
||||
|
||||
##### LoRA
|
||||
|
||||
```
|
||||
learning_rate=5e-4
|
||||
```
|
||||
|
||||
The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
|
||||
|
||||
|
||||
|
||||
## Model Deployment
|
||||
Replace `THUDM/chatglm-6b` in the corresponding demo or code with the path of the checkpoint after P-Tuning(in the example, `./output/adgen-chatglm-6b-pt-8-1e-2/ checkpoint-3000`). Note that the current fine-tuning does not support multiple rounds of data, so only the responses from the first round of the conversation are fine-tuned.
|
||||
|
||||
## Use your own dataset
|
||||
Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text.
|
||||
|
||||
## TODO
|
||||
* [ ] Support for chat data
|
||||
* [ ] Support for full finetuning
|
||||
|
||||
## quoting
|
||||
|
||||
```
|
||||
@inproceedings{liu2022p,
|
||||
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
|
||||
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
|
||||
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
|
||||
pages={61--68},
|
||||
year={2022}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,224 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
ptuning_checkpoint: str = field(
|
||||
default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
)
|
||||
},
|
||||
)
|
||||
resize_position_embeddings: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether to automatically resize the position embeddings if `max_source_length` exceeds "
|
||||
"the model's position embeddings."
|
||||
)
|
||||
},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None
|
||||
)
|
||||
pre_seq_len: Optional[int] = field(
|
||||
default=None
|
||||
)
|
||||
prefix_projection: bool = field(
|
||||
default=False
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
prompt_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
|
||||
)
|
||||
response_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
|
||||
)
|
||||
history_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the column in the datasets containing the history of chat."},
|
||||
)
|
||||
train_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
|
||||
)
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
|
||||
)
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": (
|
||||
"The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
)
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
)
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The token to force as the first generated token after the decoder_start_token_id."
|
||||
"Useful for multilingual models like mBART where the first generated token"
|
||||
"needs to be the target language token (Usually it is the target language token)"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation/test file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
{
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"zero_allow_untested_optimizer": true,
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 16,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": false,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients" : true
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
|
||||
LR=1e-4
|
||||
|
||||
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
|
||||
|
||||
deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
|
||||
--deepspeed deepspeed.json \
|
||||
--do_train \
|
||||
--train_file AdvertiseGen/train.json \
|
||||
--test_file AdvertiseGen/dev.json \
|
||||
--prompt_column content \
|
||||
--response_column summary \
|
||||
--overwrite_cache \
|
||||
--model_name_or_path THUDM/chatglm-6b \
|
||||
--output_dir ./output/adgen-chatglm-6b-ft-$LR \
|
||||
--overwrite_output_dir \
|
||||
--max_source_length 64 \
|
||||
--max_target_length 64 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--predict_with_generate \
|
||||
--max_steps 5000 \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate $LR \
|
||||
--fp16
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
PRE_SEQ_LEN=128
|
||||
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
|
||||
STEP=3000
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||
--do_predict \
|
||||
--validation_file AdvertiseGen/dev.json \
|
||||
--test_file AdvertiseGen/dev.json \
|
||||
--overwrite_cache \
|
||||
--prompt_column content \
|
||||
--response_column summary \
|
||||
--model_name_or_path THUDM/chatglm-6b \
|
||||
--ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
|
||||
--output_dir ./output/$CHECKPOINT \
|
||||
--overwrite_output_dir \
|
||||
--max_source_length 64 \
|
||||
--max_target_length 64 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--predict_with_generate \
|
||||
--pre_seq_len $PRE_SEQ_LEN \
|
||||
--quantization_bit 4
|
|
@ -0,0 +1,18 @@
|
|||
CHECKPOINT=adgen-chatglm-6b-ft-1e-4
|
||||
STEP=3000
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||
--do_predict \
|
||||
--validation_file AdvertiseGen/dev.json \
|
||||
--test_file AdvertiseGen/dev.json \
|
||||
--overwrite_cache \
|
||||
--prompt_column content \
|
||||
--response_column summary \
|
||||
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
|
||||
--output_dir ./output/$CHECKPOINT \
|
||||
--overwrite_output_dir \
|
||||
--max_source_length 256 \
|
||||
--max_target_length 256 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--predict_with_generate \
|
||||
--fp16_full_eval
|
|
@ -0,0 +1,431 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for sequence to sequence.
|
||||
"""
|
||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
import jieba
|
||||
from rouge_chinese import Rouge
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
Seq2SeqTrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from trainer_seq2seq import Seq2SeqTrainer
|
||||
|
||||
from arguments import ModelArguments, DataTrainingArguments
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
if training_args.should_log:
|
||||
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
# datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Load dataset
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
if data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
||||
config.pre_seq_len = model_args.pre_seq_len
|
||||
config.prefix_projection = model_args.prefix_projection
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
||||
|
||||
if model_args.ptuning_checkpoint is not None:
|
||||
# Evaluation
|
||||
# Loading extra state dict of prefix encoder
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
|
||||
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
|
||||
new_prefix_state_dict = {}
|
||||
for k, v in prefix_state_dict.items():
|
||||
if k.startswith("transformer.prefix_encoder."):
|
||||
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
print(f"Quantized to {model_args.quantization_bit} bit")
|
||||
model = model.quantize(model_args.quantization_bit)
|
||||
if model_args.pre_seq_len is not None:
|
||||
# P-tuning v2
|
||||
model = model.half()
|
||||
model.transformer.prefix_encoder.float()
|
||||
else:
|
||||
# Finetune
|
||||
model = model.float()
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
if training_args.do_train:
|
||||
column_names = raw_datasets["train"].column_names
|
||||
elif training_args.do_eval:
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
elif training_args.do_predict:
|
||||
column_names = raw_datasets["test"].column_names
|
||||
else:
|
||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||
return
|
||||
|
||||
# Get the column names for input/target.
|
||||
prompt_column = data_args.prompt_column
|
||||
response_column = data_args.response_column
|
||||
history_column = data_args.history_column
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
|
||||
def preprocess_function_eval(examples):
|
||||
inputs, targets = [], []
|
||||
for i in range(len(examples[prompt_column])):
|
||||
if examples[prompt_column][i] and examples[response_column][i]:
|
||||
query = examples[prompt_column][i]
|
||||
if history_column is None or len(examples[history_column][i]) == 0:
|
||||
prompt = query
|
||||
else:
|
||||
prompt = ""
|
||||
history = examples[history_column][i]
|
||||
for turn_idx, (old_query, response) in enumerate(history):
|
||||
prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
|
||||
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
||||
inputs.append(prompt)
|
||||
targets.append(examples[response_column][i])
|
||||
|
||||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
|
||||
labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
|
||||
|
||||
if data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_function_train(examples):
|
||||
max_seq_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"labels": [],
|
||||
}
|
||||
for i in range(len(examples[prompt_column])):
|
||||
if examples[prompt_column][i] and examples[response_column][i]:
|
||||
query, answer = examples[prompt_column][i], examples[response_column][i]
|
||||
|
||||
if history_column is None:
|
||||
prompt = query
|
||||
else:
|
||||
prompt = ""
|
||||
history = examples[history_column][i]
|
||||
for turn_idx, (old_query, response) in enumerate(history):
|
||||
prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
|
||||
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
||||
|
||||
prompt = prefix + prompt
|
||||
a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
||||
|
||||
if len(a_ids) > data_args.max_source_length - 1:
|
||||
a_ids = a_ids[: data_args.max_source_length - 1]
|
||||
|
||||
if len(b_ids) > data_args.max_target_length - 2:
|
||||
b_ids = b_ids[: data_args.max_target_length - 2]
|
||||
|
||||
input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
|
||||
|
||||
context_length = input_ids.index(tokenizer.bos_token_id)
|
||||
mask_position = context_length - 1
|
||||
labels = [-100] * context_length + input_ids[mask_position+1:]
|
||||
|
||||
pad_len = max_seq_length - len(input_ids)
|
||||
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
|
||||
labels = labels + [tokenizer.pad_token_id] * pad_len
|
||||
if data_args.ignore_pad_token_for_loss:
|
||||
labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_dataset_example(example):
|
||||
print("input_ids",example["input_ids"])
|
||||
print("inputs", tokenizer.decode(example["input_ids"]))
|
||||
print("label_ids", example["labels"])
|
||||
print("labels", tokenizer.decode(example["labels"]))
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = raw_datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function_train,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on train dataset",
|
||||
)
|
||||
print_dataset_example(train_dataset[0])
|
||||
|
||||
if training_args.do_eval:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "validation" not in raw_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = raw_datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
with training_args.main_process_first(desc="validation dataset map pre-processing"):
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function_eval,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on validation dataset",
|
||||
)
|
||||
print_dataset_example(eval_dataset[0])
|
||||
|
||||
if training_args.do_predict:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "test" not in raw_datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
predict_dataset = raw_datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
predict_dataset = predict_dataset.select(range(max_predict_samples))
|
||||
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
|
||||
predict_dataset = predict_dataset.map(
|
||||
preprocess_function_eval,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on prediction dataset",
|
||||
)
|
||||
print_dataset_example(predict_dataset[0])
|
||||
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=None,
|
||||
padding=False
|
||||
)
|
||||
|
||||
# Metric
|
||||
def compute_metrics(eval_preds):
|
||||
preds, labels = eval_preds
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
if data_args.ignore_pad_token_for_loss:
|
||||
# Replace -100 in the labels as we can't decode them.
|
||||
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
score_dict = {
|
||||
"rouge-1": [],
|
||||
"rouge-2": [],
|
||||
"rouge-l": [],
|
||||
"bleu-4": []
|
||||
}
|
||||
for pred, label in zip(decoded_preds, decoded_labels):
|
||||
hypothesis = list(jieba.cut(pred))
|
||||
reference = list(jieba.cut(label))
|
||||
rouge = Rouge()
|
||||
scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
|
||||
result = scores[0]
|
||||
|
||||
for k, v in result.items():
|
||||
score_dict[k].append(round(v["f"] * 100, 4))
|
||||
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
|
||||
for k, v in score_dict.items():
|
||||
score_dict[k] = float(np.mean(v))
|
||||
return score_dict
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
training_args.generation_num_beams = (
|
||||
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
)
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
save_prefixencoder=model_args.pre_seq_len is not None
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
# elif last_checkpoint is not None:
|
||||
# checkpoint = last_checkpoint
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
# trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95)
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95)
|
||||
metrics = predict_results.metrics
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
if training_args.predict_with_generate:
|
||||
predictions = tokenizer.batch_decode(
|
||||
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
predictions = [pred.strip() for pred in predictions]
|
||||
labels = tokenizer.batch_decode(
|
||||
predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
labels = [label.strip() for label in labels]
|
||||
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
for p, l in zip(predictions, labels):
|
||||
res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
|
||||
writer.write(f"{res}\n")
|
||||
return results
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,26 @@
|
|||
PRE_SEQ_LEN=128
|
||||
LR=2e-2
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||
--do_train \
|
||||
--train_file AdvertiseGen/train.json \
|
||||
--validation_file AdvertiseGen/dev.json \
|
||||
--prompt_column content \
|
||||
--response_column summary \
|
||||
--overwrite_cache \
|
||||
--model_name_or_path THUDM/chatglm-6b \
|
||||
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
|
||||
--overwrite_output_dir \
|
||||
--max_source_length 64 \
|
||||
--max_target_length 64 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--predict_with_generate \
|
||||
--max_steps 3000 \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate $LR \
|
||||
--pre_seq_len $PRE_SEQ_LEN \
|
||||
--quantization_bit 4
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
PRE_SEQ_LEN=8
|
||||
LR=1e-2
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||
--do_train \
|
||||
--train_file $CHAT_TRAIN_DATA \
|
||||
--validation_file $CHAT_VAL_DATA \
|
||||
--prompt_column prompt \
|
||||
--response_column response \
|
||||
--history_column history \
|
||||
--overwrite_cache \
|
||||
--model_name_or_path THUDM/chatglm-6b \
|
||||
--output_dir $CHECKPOINT_NAME \
|
||||
--overwrite_output_dir \
|
||||
--max_source_length 256 \
|
||||
--max_target_length 256 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--predict_with_generate \
|
||||
--max_steps 3000 \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate $LR \
|
||||
--pre_seq_len $PRE_SEQ_LEN \
|
||||
--quantization_bit 4
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,247 @@
|
|||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
from trainer import Trainer
|
||||
from transformers.trainer_utils import PredictionOutput
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
def evaluate(
|
||||
self,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
**gen_kwargs
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Run evaluation and returns metrics.
|
||||
|
||||
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
|
||||
(pass it to the init `compute_metrics` argument).
|
||||
|
||||
You can also subclass and override this method to inject custom behavior.
|
||||
|
||||
Args:
|
||||
eval_dataset (`Dataset`, *optional*):
|
||||
Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
|
||||
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
|
||||
method.
|
||||
ignore_keys (`List[str]`, *optional*):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
|
||||
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
||||
"eval_bleu" if the prefix is `"eval"` (default)
|
||||
max_length (`int`, *optional*):
|
||||
The maximum target length to use when predicting with the generate method.
|
||||
num_beams (`int`, *optional*):
|
||||
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
||||
beam search.
|
||||
gen_kwargs:
|
||||
Additional `generate` specific kwargs.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||
dictionary also contains the epoch number which comes from the training state.
|
||||
"""
|
||||
|
||||
gen_kwargs = gen_kwargs.copy()
|
||||
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
|
||||
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||
gen_kwargs["num_beams"] = (
|
||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
||||
)
|
||||
self._gen_kwargs = gen_kwargs
|
||||
|
||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
test_dataset: Dataset,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "test",
|
||||
**gen_kwargs
|
||||
) -> PredictionOutput:
|
||||
"""
|
||||
Run prediction and returns predictions and potential metrics.
|
||||
|
||||
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
|
||||
will also return metrics, like in `evaluate()`.
|
||||
|
||||
Args:
|
||||
test_dataset (`Dataset`):
|
||||
Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
|
||||
`model.forward()` method are automatically removed. Has to implement the method `__len__`
|
||||
ignore_keys (`List[str]`, *optional*):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
|
||||
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
||||
"eval_bleu" if the prefix is `"eval"` (default)
|
||||
max_length (`int`, *optional*):
|
||||
The maximum target length to use when predicting with the generate method.
|
||||
num_beams (`int`, *optional*):
|
||||
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
||||
beam search.
|
||||
gen_kwargs:
|
||||
Additional `generate` specific kwargs.
|
||||
|
||||
<Tip>
|
||||
|
||||
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
|
||||
padding in a token classification task) the predictions will be padded (on the right) to allow for
|
||||
concatenation into one array. The padding index is -100.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns: *NamedTuple* A namedtuple with the following keys:
|
||||
|
||||
- predictions (`np.ndarray`): The predictions on `test_dataset`.
|
||||
- label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
|
||||
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
|
||||
labels).
|
||||
"""
|
||||
|
||||
gen_kwargs = gen_kwargs.copy()
|
||||
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
|
||||
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||
gen_kwargs["num_beams"] = (
|
||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
||||
)
|
||||
self._gen_kwargs = gen_kwargs
|
||||
|
||||
|
||||
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on `model` using `inputs`.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Args:
|
||||
model (`nn.Module`):
|
||||
The model to evaluate.
|
||||
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument `labels`. Check your model's documentation for all accepted arguments.
|
||||
prediction_loss_only (`bool`):
|
||||
Whether or not to return the loss only.
|
||||
|
||||
Return:
|
||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
|
||||
labels (each being optional).
|
||||
"""
|
||||
|
||||
if not self.args.predict_with_generate or prediction_loss_only:
|
||||
return super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
|
||||
has_labels = "labels" in inputs
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
# XXX: adapt synced_gpus for fairscale as well
|
||||
gen_kwargs = self._gen_kwargs.copy()
|
||||
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
|
||||
gen_kwargs["max_length"] = self.model.config.max_length
|
||||
gen_kwargs["num_beams"] = (
|
||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
|
||||
)
|
||||
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
|
||||
gen_kwargs["synced_gpus"] = (
|
||||
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
|
||||
)
|
||||
|
||||
if "attention_mask" in inputs:
|
||||
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
|
||||
if "position_ids" in inputs:
|
||||
gen_kwargs["position_ids"] = inputs.get("position_ids", None)
|
||||
if "global_attention_mask" in inputs:
|
||||
gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
|
||||
|
||||
# prepare generation inputs
|
||||
# some encoder-decoder models can have varying encoder's and thus
|
||||
# varying model input names
|
||||
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
|
||||
generation_inputs = inputs[self.model.encoder.main_input_name]
|
||||
else:
|
||||
generation_inputs = inputs[self.model.main_input_name]
|
||||
|
||||
gen_kwargs["input_ids"] = generation_inputs
|
||||
generated_tokens = self.model.generate(**gen_kwargs)
|
||||
generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
|
||||
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||
elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
|
||||
gen_kwargs["max_new_tokens"] + 1
|
||||
):
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
|
||||
|
||||
loss = None
|
||||
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
if has_labels:
|
||||
labels = inputs["labels"]
|
||||
if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
|
||||
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
|
||||
elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
|
||||
gen_kwargs["max_new_tokens"] + 1
|
||||
):
|
||||
labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
|
||||
else:
|
||||
labels = None
|
||||
|
||||
return (loss, generated_tokens, labels)
|
||||
|
||||
def _pad_tensors_to_max_len(self, tensor, max_length):
|
||||
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
|
||||
# If PAD token is not defined at least EOS token has to be defined
|
||||
pad_token_id = (
|
||||
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
||||
)
|
||||
else:
|
||||
if self.model.config.pad_token_id is not None:
|
||||
pad_token_id = self.model.config.pad_token_id
|
||||
else:
|
||||
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
|
||||
|
||||
padded_tensor = pad_token_id * torch.ones(
|
||||
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
padded_tensor[:, : tensor.shape[-1]] = tensor
|
||||
return padded_tensor
|
|
@ -1,7 +1,8 @@
|
|||
protobuf>=3.19.5,<3.20.1
|
||||
transformers==4.26.1
|
||||
icetk
|
||||
protobuf
|
||||
transformers==4.27.1
|
||||
cpm_kernels
|
||||
torch>=1.10
|
||||
gradio
|
||||
mdtex2html
|
||||
sentencepiece
|
||||
accelerate
|
|
@ -0,0 +1,7 @@
|
|||
<div align="center">
|
||||
<img src=wechat.jpg width="60%"/>
|
||||
|
||||
<p> 扫码关注公众号,加入「ChatGLM交流群」 </p>
|
||||
<p> Scan the QR code to follow the official account and join the "ChatGLM Discussion Group" </p>
|
||||
</div>
|
||||
|
Binary file not shown.
After Width: | Height: | Size: 151 KiB |
102
web_demo.py
102
web_demo.py
|
@ -1,44 +1,100 @@
|
|||
import gradio as gr
|
||||
import mdtex2html
|
||||
|
||||
from utils import load_model_and_tokenizer
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||
|
||||
MAX_TURNS = 20
|
||||
MAX_BOXES = MAX_TURNS * 2
|
||||
"""Override Chatbot.postprocess"""
|
||||
|
||||
|
||||
def predict(input, max_length, top_p, temperature, history=None):
|
||||
if history is None:
|
||||
history = []
|
||||
def postprocess(self, y):
|
||||
if y is None:
|
||||
return []
|
||||
for i, (message, response) in enumerate(y):
|
||||
y[i] = (
|
||||
None if message is None else mdtex2html.convert((message)),
|
||||
None if response is None else mdtex2html.convert(response),
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
gr.Chatbot.postprocess = postprocess
|
||||
|
||||
|
||||
def parse_text(text):
|
||||
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
|
||||
lines = text.split("\n")
|
||||
lines = [line for line in lines if line != ""]
|
||||
count = 0
|
||||
for i, line in enumerate(lines):
|
||||
if "```" in line:
|
||||
count += 1
|
||||
items = line.split('`')
|
||||
if count % 2 == 1:
|
||||
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
||||
else:
|
||||
lines[i] = f'<br></code></pre>'
|
||||
else:
|
||||
if i > 0:
|
||||
if count % 2 == 1:
|
||||
line = line.replace("`", "\`")
|
||||
line = line.replace("<", "<")
|
||||
line = line.replace(">", ">")
|
||||
line = line.replace(" ", " ")
|
||||
line = line.replace("*", "*")
|
||||
line = line.replace("_", "_")
|
||||
line = line.replace("-", "-")
|
||||
line = line.replace(".", ".")
|
||||
line = line.replace("!", "!")
|
||||
line = line.replace("(", "(")
|
||||
line = line.replace(")", ")")
|
||||
line = line.replace("$", "$")
|
||||
lines[i] = "<br>"+line
|
||||
text = "".join(lines)
|
||||
return text
|
||||
|
||||
|
||||
def predict(input, chatbot, max_length, top_p, temperature, history):
|
||||
chatbot.append((parse_text(input), ""))
|
||||
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
|
||||
temperature=temperature):
|
||||
updates = []
|
||||
for query, response in history:
|
||||
updates.append(gr.update(visible=True, value="用户:" + query))
|
||||
updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
|
||||
if len(updates) < MAX_BOXES:
|
||||
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
|
||||
yield [history] + updates
|
||||
chatbot[-1] = (parse_text(input), parse_text(response))
|
||||
|
||||
yield chatbot, history
|
||||
|
||||
|
||||
def reset_user_input():
|
||||
return gr.update(value='')
|
||||
|
||||
|
||||
def reset_state():
|
||||
return [], []
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
state = gr.State([])
|
||||
text_boxes = []
|
||||
for i in range(MAX_BOXES):
|
||||
if i % 2 == 0:
|
||||
text_boxes.append(gr.Markdown(visible=False, label="提问:"))
|
||||
else:
|
||||
text_boxes.append(gr.Markdown(visible=False, label="回复:"))
|
||||
gr.HTML("""<h1 align="center">ChatGLM</h1>""")
|
||||
|
||||
chatbot = gr.Chatbot()
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
|
||||
container=False)
|
||||
with gr.Column(scale=12):
|
||||
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
|
||||
container=False)
|
||||
with gr.Column(min_width=32, scale=1):
|
||||
submitBtn = gr.Button("Submit", variant="primary")
|
||||
with gr.Column(scale=1):
|
||||
emptyBtn = gr.Button("Clear History")
|
||||
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
|
||||
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
|
||||
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
|
||||
button = gr.Button("Generate")
|
||||
button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
|
||||
|
||||
history = gr.State([])
|
||||
|
||||
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
|
||||
show_progress=True)
|
||||
submitBtn.click(reset_user_input, [], [user_input])
|
||||
|
||||
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
||||
|
||||
demo.queue().launch(share=False, inbrowser=True)
|
||||
|
|
16
web_demo2.py
16
web_demo2.py
|
@ -19,7 +19,7 @@ MAX_TURNS = 20
|
|||
MAX_BOXES = MAX_TURNS * 2
|
||||
|
||||
|
||||
def predict(input, history=None):
|
||||
def predict(input, max_length, top_p, temperature, history=None):
|
||||
tokenizer, model = get_model()
|
||||
if history is None:
|
||||
history = []
|
||||
|
@ -33,7 +33,8 @@ def predict(input, history=None):
|
|||
message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
|
||||
st.write("AI正在回复:")
|
||||
with st.empty():
|
||||
for response, history in model.stream_chat(tokenizer, input, history):
|
||||
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
|
||||
temperature=temperature):
|
||||
query, response = history[-1]
|
||||
st.write(response)
|
||||
|
||||
|
@ -47,6 +48,15 @@ prompt_text = st.text_area(label="用户命令输入",
|
|||
height = 100,
|
||||
placeholder="请在这儿输入您的命令")
|
||||
|
||||
max_length = st.sidebar.slider(
|
||||
'max_length', 0, 4096, 2048, step=1
|
||||
)
|
||||
top_p = st.sidebar.slider(
|
||||
'top_p', 0.0, 1.0, 0.6, step=0.01
|
||||
)
|
||||
temperature = st.sidebar.slider(
|
||||
'temperature', 0.0, 1.0, 0.95, step=0.01
|
||||
)
|
||||
|
||||
if 'state' not in st.session_state:
|
||||
st.session_state['state'] = []
|
||||
|
@ -54,4 +64,4 @@ if 'state' not in st.session_state:
|
|||
if st.button("发送", key="predict"):
|
||||
with st.spinner("AI正在思考,请稍等........"):
|
||||
# text generation
|
||||
st.session_state["state"] = predict(prompt_text, st.session_state["state"])
|
||||
st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
|
|
@ -0,0 +1,45 @@
|
|||
from transformers import AutoModel, AutoTokenizer
|
||||
import gradio as gr
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
||||
model = model.eval()
|
||||
|
||||
MAX_TURNS = 20
|
||||
MAX_BOXES = MAX_TURNS * 2
|
||||
|
||||
|
||||
def predict(input, max_length, top_p, temperature, history=None):
|
||||
if history is None:
|
||||
history = []
|
||||
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
|
||||
temperature=temperature):
|
||||
updates = []
|
||||
for query, response in history:
|
||||
updates.append(gr.update(visible=True, value="用户:" + query))
|
||||
updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
|
||||
if len(updates) < MAX_BOXES:
|
||||
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
|
||||
yield [history] + updates
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
state = gr.State([])
|
||||
text_boxes = []
|
||||
for i in range(MAX_BOXES):
|
||||
if i % 2 == 0:
|
||||
text_boxes.append(gr.Markdown(visible=False, label="提问:"))
|
||||
else:
|
||||
text_boxes.append(gr.Markdown(visible=False, label="回复:"))
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
|
||||
container=False)
|
||||
with gr.Column(scale=1):
|
||||
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
|
||||
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
|
||||
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
|
||||
button = gr.Button("Generate")
|
||||
button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
|
||||
demo.queue().launch(share=False, inbrowser=True)
|
Loading…
Reference in New Issue