diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c3dd476 --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +history/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Mac system file +model/ \ No newline at end of file diff --git a/PROJECT.md b/PROJECT.md new file mode 100644 index 0000000..4798259 --- /dev/null +++ b/PROJECT.md @@ -0,0 +1,18 @@ +# 友情链接 + +以下是部分基于本仓库开发的开源项目: +* [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。 +* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU +* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf) +* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于本地知识的 ChatGLM 应用,基于LangChain +* [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。 +* [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能 +* [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署 +* [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。 +* [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM):基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题 +* [ChatGLM-web](https://github.com/NCZkevin/chatglm-web):基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能) +* [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM + +以下是部分针对本项目的教程/文档: +* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md) +* [ChatGLM-6B 的部署与微调教程 @ModelWhale平台](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e) diff --git a/README.md b/README.md index 8e2bb86..9b08ff7 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,42 @@ # ChatGLM-6B +

+ 🌐 Blog • 🤗 HF Repo • 🐦 Twitter • 📃 [GLM@ACL 22] [GitHub] • 📃 [GLM-130B@ICLR 23] [GitHub]
+

+

+ 👋 加入我们的 SlackWeChat +

+ ## 介绍 ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。 -ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。更多信息请参考我们的[博客](https://chatglm.cn/blog)。 - -不过,由于 ChatGLM-6B 的规模较小,目前已知其具有相当多的[**局限性**](#局限性),如事实性/数学逻辑错误,可能生成有害/有偏见内容,较弱的上下文能力,自我认知混乱,以及对英文指示生成与中文指示完全矛盾的内容。请大家在使用前了解这些问题,以免产生误解。更大的基于1300亿参数[GLM-130B](https://github.com/THUDM/GLM-130B)的ChatGLM正在内测开发中。 +ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答,更多信息请参考我们的[博客](https://chatglm.cn/blog)。 -*Read this in [English](README_en.md).* +为了方便下游开发者针对自己的应用场景定制模型,我们同时实现了基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调方法 [(使用指南)](ptuning/README.md) ,INT4 量化级别下最低只需 7GB 显存即可启动微调。 -## 更新信息 -**[2023/03/23]** 增加API部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。增加Embedding量化模型[ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)。增加对基于Apple Silicon的Mac上GPU加速的支持。 +不过,由于 ChatGLM-6B 的规模较小,目前已知其具有相当多的[**局限性**](#局限性),如事实性/数学逻辑错误,可能生成有害/有偏见内容,较弱的上下文能力,自我认知混乱,以及对英文指示生成与中文指示完全矛盾的内容。请大家在使用前了解这些问题,以免产生误解。更大的基于 1300 亿参数 [GLM-130B](https://github.com/THUDM/GLM-130B) 的 ChatGLM 正在内测开发中。 -**[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4) +*Read this in [English](README_en.md).* ## 友情链接 -以下是部分基于本仓库开发的开源项目: -* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小动态分配计算任务给 GPU 和 CPU -* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调 +部分基于本仓库开发的开源项目参见 [PROJECT.md](PROJECT.md) -如果你有其他好的项目的话,欢迎参照上述格式添加到README中并提出 [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork). +如果你有其他好的项目/教程的话,欢迎参照上述格式添加到 README 中并提出 [Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)。 ## 使用方式 ### 硬件需求 -| **量化等级** | **最低 GPU 显存** | -| -------------- | ----------------- | -| FP16(无量化) | 13 GB | -| INT8 | 10 GB | -| INT4 | 6 GB | - +| **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) | +| -------------- | ------------------------- | --------------------------------- | +| FP16(无量化) | 13 GB | 14 GB | +| INT8 | 8 GB | 9 GB | +| INT4 | 6 GB | 7 GB | ### 环境安装 -使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.26.1`,但理论上不低于 `4.23.1` 即可。 +使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。 + +此外,如果需要在 cpu 上运行量化后的模型,还需要安装 `gcc` 与 `openmp`。多数 Linux 发行版默认已安装。对于 Windows ,可在安装 [TDM-GCC](https://jmeubank.github.io/tdm-gcc/) 时勾选 `openmp`。 Windows 测试环境 `gcc` 版本为 `TDM-GCC 10.3.0`, Linux 为 `gcc 11.3.0`。 ### 代码调用 @@ -60,9 +63,23 @@ ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进 如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。 ``` -完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b) 上查看。如果你从 Hugging Face Hub 上下载checkpoint的速度较慢,也可以从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载。 +### 从本地加载模型 +以上代码会由 `transformers` 自动下载模型实现和参数。完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b)。如果你的网络环境较差,下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地,然后从本地加载。 -### Demo +从 Hugging Face Hub 下载模型需要先[安装Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行 +```Shell +git clone https://huggingface.co/THUDM/chatglm-6b +``` + +如果你从 Hugging Face Hub 上下载 checkpoint 的速度较慢,可以只下载模型实现 +```Shell +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b +``` +然后从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载模型参数文件,并将下载的文件替换到本地的 `chatglm-6b` 目录下。 + +将模型下载到本地之后,将以上代码中的 `THUDM/chatglm-6b` 替换为你本地的 `chatglm-6b` 文件夹的路径,即可从本地加载模型。 + +## Demo & API 我们提供了一个基于 [Gradio](https://gradio.app) 的网页版 Demo 和一个命令行 Demo。使用时首先需要下载本仓库: @@ -95,14 +112,14 @@ python web_demo.py python cli_demo.py ``` -程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入`clear`可以清空对话历史,输入`stop`终止程序。 +程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入 `clear` 可以清空对话历史,输入 `stop` 终止程序。 ### API部署 -首先需要安装额外的依赖`pip install fastapi uvicorn`,然后运行仓库中的[api.py](api.py): +首先需要安装额外的依赖 `pip install fastapi uvicorn`,然后运行仓库中的 [api.py](api.py): ```shell python api.py ``` -默认部署在本地的8000端口,通过POST方法进行调用 +默认部署在本地的 8000 端口,通过 POST 方法进行调用 ```shell curl -X POST "http://127.0.0.1:8000" \ -H 'Content-Type: application/json' \ @@ -124,44 +141,35 @@ curl -X POST "http://127.0.0.1:8000" \ ```python # 按需修改,目前只支持 4/8 bit 量化 -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda() +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(4).half().cuda() ``` 进行 2 至 3 轮对话后,8-bit 量化下 GPU 显存占用约为 10GB,4-bit 量化下仅需 6GB 占用。随着对话轮数的增多,对应消耗显存也随之增长,由于采用了相对位置编码,理论上 ChatGLM-6B 支持无限长的 context-length,但总长度超过 2048(训练长度)后性能会逐渐下降。 模型量化会带来一定的性能损失,经过测试,ChatGLM-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。使用 [GPT-Q](https://arxiv.org/abs/2210.17323) 等量化方案可以进一步压缩量化精度/提升相同量化精度下的模型性能,欢迎大家提出对应的 Pull Request。 -**[2023/03/19]** 量化过程需要在内存中首先加载 FP16 格式的模型,消耗大概 13GB 的内存。如果你的内存不足的话,可以直接加载量化后的模型,仅需大概 5.2GB 的内存: +量化过程需要在内存中首先加载 FP16 格式的模型,消耗大概 13GB 的内存。如果你的内存不足的话,可以直接加载量化后的模型,仅需大概 5.2GB 的内存: ```python model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda() ``` -**[2023/03/24]** 我们进一步提供了对Embedding量化后的模型,模型参数仅占用4.3 GB显存: -```python -model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4-qe", trust_remote_code=True).half().cuda() -``` - - - ### CPU 部署 如果你没有 GPU 硬件的话,也可以在 CPU 上进行推理,但是推理速度会更慢。使用方法如下(需要大概 32GB 内存) ```python model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float() ``` -**[2023/03/19]** 如果你的内存不足,可以直接加载量化后的模型: +如果你的内存不足,可以直接加载量化后的模型: ```python model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float() ``` -如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) 的话请参考这个[Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041). +如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) ,请[从本地加载模型](README.md#从本地加载模型) ### Mac 上的 GPU 加速 -对于搭载了Apple Silicon的Mac(以及MacBook),可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。首先需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly。然后将模型仓库 clone 到本地 -```shell -git clone https://huggingface.co/THUDM/chatglm-6b -``` -将代码中的模型加载改为从本地加载,并使用 mps 后端 +对于搭载了Apple Silicon的Mac(以及MacBook),可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly。 + +目前在 MacOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载,并使用 mps 后端 ```python model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps') ``` @@ -178,6 +186,19 @@ from utils import load_model_and_tokenizer model, tokenizer = load_model_and_tokenizer("your local path", num_gpus=2) ``` 即可将模型部署到多卡上进行推理。 + +## 高效参数微调 +基于 [P-tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调。具体使用方法详见 [ptuning/README.md](ptuning/README.md)。 + +## 更新信息 +**[2023/04/06]** 优化web demo的界面(感谢 [@tuteng0915](https://github.com/tuteng0915))。移除embedding中的image token以减小显存占用(需要更新模型文件`pytorch_model-00001-of-00008.bin`和`pytorch_model-00008-of-00008.bin`,感谢 [@silverriver](https://github.com/silverriver) 提出的想法)。去掉了对 `icetk` 的依赖(需要更新模型文件`ice_text.model`)。 + +**[2023/03/31]** 增加基于 [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调实现,INT4 量化级别下最低只需 7GB 显存即可进行模型微调。详见[高效参数微调方法](ptuning/README.md)。 + +**[2023/03/23]** 增加 API 部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。增加 Embedding 量化模型 [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)。增加配备 Apple Silicon 芯片的 Mac 上 GPU 加速的支持。 + +**[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4) + ## ChatGLM-6B 示例 以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现! diff --git a/README_en.md b/README_en.md index 63f407a..257776e 100644 --- a/README_en.md +++ b/README_en.md @@ -1,5 +1,13 @@ # ChatGLM-6B + +

+ 🌐 Blog • 🤗 HF Repo • 🐦 Twitter • 📃 [GLM@ACL 22] [GitHub] • 📃 [GLM-130B@ICLR 23] [GitHub]
+

+

+ 👋 Join our Slack and WeChat +

+ ## Introduction ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). @@ -9,13 +17,15 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dial Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Huggingface Spaces. ## Update -**[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe) +**[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). The minimum INT4 quantization level only needs 7GB GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details. + +**[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon. **[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4). ## Projects The following are some open source projects developed based on this repository: -* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An [MNN](https://github.com/alibaba/MNN)-based implementation of ChatGLM-6B C++ inference, which supports dynamic allocation of computing tasks to GPU and CPU according to the size of GPU memory +* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An [MNN](https://github.com/alibaba/MNN)-based implementation of ChatGLM-6B C++ inference, which supports automatic allocation of computing tasks to GPU and CPU according to the size of GPU memory * [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): Fine-tuning ChatGLM-6B based on LoRA If you have other good projects, please refer to the above format to add to README and propose [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork). @@ -32,7 +42,9 @@ If you have other good projects, please refer to the above format to add to READ ### Environment Setup -Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.26.1`, but theoretically any version no lower than `4.23.1` is acceptable. +Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.27.1`, but theoretically any version no lower than `4.23.1` is acceptable. + +In addition, if you need to run the quantified model on the CPU, you also need to install `gcc` and `openmp`. Most Linux distributions are installed by default. For Windows, you can check `openmp` when installing [TDM-GCC](https://jmeubank.github.io/tdm-gcc/). On Windows testing environment, the `gcc` version is `TDM-GCC 10.3.0`, and on Linux is `gcc 11.3.0`. ### Usage @@ -136,11 +148,6 @@ Model quantization brings a certain performance decline. After testing, ChatGLM- model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda() ``` -**[2023/03/24]** We further provide an embedding-quantized model whose model parameters only cost 4.3GB GPU memory -```python -model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4-qe", trust_remote_code=True).half().cuda() -``` - ### CPU Deployment If your computer is not equipped with GPU, you can also conduct inference on CPU, but the inference speed is slow (and taking about 32GB of memory): @@ -154,7 +161,23 @@ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).fl model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float() ``` -**For Mac users**: if your encounter the error `RuntimeError: Unknown platform: darwin`, please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041). +If your encounter the error `Could not find module 'nvcuda.dll'` or `RuntimeError: Unknown platform: darwin`(MacOS), please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041). + +### GPU Inference on Mac +For Macs (and MacBooks) with Apple Silicon, it is possible to use the MPS backend to run ChatGLM-6B on the GPU. First, you need to refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly. Then clone the model repository locally (you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage)) +```shell +git lfs install +git clone https://huggingface.co/THUDM/chatglm-6b +``` +Change the code to load the model from your local path, and use the mps backend: +```python +model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps') +``` +Then you can use GPU-accelerated model inference on Mac. + +## Parameter-efficient Tuning +Parameter-efficient tuning based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2). See [ptuning/README.md](ptuning/README.md) for details on how to use it. + ### Multi-GPU Deployment diff --git a/api.py b/api.py index ab2066e..693c70a 100644 --- a/api.py +++ b/api.py @@ -1,10 +1,19 @@ -import datetime -import json - -import uvicorn from fastapi import FastAPI, Request +from transformers import AutoTokenizer, AutoModel +import uvicorn, json, datetime +import torch + +DEVICE = "cuda" +DEVICE_ID = "0" +CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE + + +def torch_gc(): + if torch.cuda.is_available(): + with torch.cuda.device(CUDA_DEVICE): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() -from utils import load_model_and_tokenizer app = FastAPI() @@ -17,7 +26,15 @@ async def create_item(request: Request): json_post_list = json.loads(json_post) prompt = json_post_list.get('prompt') history = json_post_list.get('history') - response, history = model.chat(tokenizer, prompt, history=history) + max_length = json_post_list.get('max_length') + top_p = json_post_list.get('top_p') + temperature = json_post_list.get('temperature') + response, history = model.chat(tokenizer, + prompt, + history=history, + max_length=max_length if max_length else 2048, + top_p=top_p if top_p else 0.7, + temperature=temperature if temperature else 0.95) now = datetime.datetime.now() time = now.strftime("%Y-%m-%d %H:%M:%S") answer = { @@ -28,10 +45,12 @@ async def create_item(request: Request): } log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"' print(log) + torch_gc() return answer if __name__ == '__main__': - uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1) - -model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) + tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) + model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() + model.eval() + uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) diff --git a/cli_demo.py b/cli_demo.py index fc48ca4..da80fff 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -1,12 +1,15 @@ import os import platform +import signal +from transformers import AutoTokenizer, AutoModel -from utils import load_model_and_tokenizer - -model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +model = model.eval() os_name = platform.system() clear_command = 'cls' if os_name == 'Windows' else 'clear' +stop_stream = False def build_prompt(history): @@ -17,24 +20,35 @@ def build_prompt(history): return prompt +def signal_handler(signal, frame): + global stop_stream + stop_stream = True + + def main(): history = [] + global stop_stream print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") while True: query = input("\n用户:") - if query == "stop": + if query.strip() == "stop": break - if query == "clear": + if query.strip() == "clear": history = [] os.system(clear_command) print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") continue count = 0 for response, history in model.stream_chat(tokenizer, query, history=history): - count += 1 - if count % 8 == 0: - os.system(clear_command) - print(build_prompt(history), flush=True) + if stop_stream: + stop_stream = False + break + else: + count += 1 + if count % 8 == 0: + os.system(clear_command) + print(build_prompt(history), flush=True) + signal.signal(signal.SIGINT, signal_handler) os.system(clear_command) print(build_prompt(history), flush=True) diff --git a/ptuning/README.md b/ptuning/README.md new file mode 100644 index 0000000..e3339ce --- /dev/null +++ b/ptuning/README.md @@ -0,0 +1,248 @@ +# ChatGLM-6B-PT +本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。 + +下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。 + +*Read this in [English](README_en.md).* + +## 软件依赖 +运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖 +``` +pip install rouge_chinese nltk jieba datasets +``` +## 使用方法 + +### 下载数据集 +ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。 + +```json +{ + "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", + "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" +} +``` + +从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 ADGEN 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。 + +### 训练 + +#### P-tuning v2 + +运行以下指令进行训练: +```shell +bash train.sh +``` +`train.sh` 中的 `PRE_SEQ_LEN` 和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 `quantization_bit` 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。 + +在默认配置 `quantization_bit=4`、`per_device_train_batch_size=1`、`gradient_accumulation_steps=16` 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 `per_device_train_batch_size` 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。 + +如果你想要[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B),可以将 `train.sh` 中的 `THUDM/chatglm-6b` 改为你本地的模型路径。 + +#### Finetune + +如果需要进行全参数的 Finetune,需要安装 [Deepspeed](https://github.com/microsoft/DeepSpeed),然后运行以下指令: + +```shell +bash ds_train_finetune.sh +``` + +### 推理 + +将 `evaluate.sh` 中的 `CHECKPOINT` 更改为训练时保存的 checkpoint 名称,运行以下指令进行模型推理和评测: +```shell +bash evaluate.sh +``` +**[2023/04/10更新]** 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重,因此需要指定参数(已更新 `evaluate.sh`) : + +```shell +--model_name_or_path THUDM/chatglm-6b +--ptuning_checkpoint $CHECKPOINT_PATH +``` + +仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`: + +```shell +--model_name_or_path $CHECKPOINT_PATH +``` + +评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 +`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。 + +### 例子 +#### 示例1 +* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 +* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 +* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 +* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 + +#### 示例2 + +* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 +* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 +* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 +* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 + +### 评估结果 + +| | Finetune | P-tuning v2 | LoRA | +| ------------- | ----------- | ----- | ------------- | +| BLEU-4 | 8.01 | 8.10 | 7.62 | +| Rouge-1 | 31.23 | 31.12 | 30.60 | +| Rouge-2 | 7.36 | 7.11 | 6.96 | +| Rouge-l | 25.08 | 24.97 | 24.80 | +| Training Loss | 3.00 | 3.74 | 3.32 | + + + +#### 实验设置 + + ``` +max_source_length=64 +max_target_length=64 +max_steps=3000 + ``` + +##### P-tuning v2 + +``` +pre_seq_len=128 +learning_rate=2e-2 +quantization_bit=4 +per_device_train_batch_size=16 +gradient_accumulation_steps=1 +``` + +##### Finetune + +``` +learning_rate=1e-4 +fp16 +num_gpus=4 +per_device_train_batch_size=4 +gradient_accumulation_steps=1 +``` + +##### LoRA + +实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) + +``` +learning_rate=5e-4 +per_device_train_batch_size=16 +gradient_accumulation_steps=1 +``` + + + +## 模型部署 +首先载入Tokenizer: + +```python +import os +import torch +from transformers import AutoConfig, AutoModel, AutoTokenizer + +# 载入Tokenizer +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +``` + +1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数): + +```python +config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True) +prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin")) +new_prefix_state_dict = {} +for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v +model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) +``` +注意你可能需要将 `pre_seq_len` 改成你训练时的实际值。如果你是[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)的话,需要将 `THUDM/chatglm-6b` 改成本地的模型路径(注意不是checkpoint路径)。 + +2. 如果需要加载的是旧 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 参数),或者进行的是全参数微调,则直接加载整个 Checkpoint: + +```python +model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True) +``` + +之后根据需求可以进行量化,也可以直接使用: + +```python +# Comment out the following line if you don't use quantization +model = model.quantize(4) +model = model.half().cuda() +model.transformer.prefix_encoder.float() +model = model.eval() + +response, history = model.chat(tokenizer, "你好", history=[]) +``` + +## 使用自己的数据集 +修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 `max_source_length` 和 `max_target_length` 来匹配你自己的数据集中的最大输入输出长度。 + +## 对话数据集 + +如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如 + +```json +{ + "prompt": "是的。上下水管都好的", + "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", + "history": [ + [ + "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", + "用电脑能读数据流吗?水温多少" + ], + [ + "95", + "上下水管温差怎么样啊?空气是不是都排干净了呢?" + ] + ] +} +``` + +训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接,例如: + +- Input + + ``` + [Round 0] + 问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线 + 答:用电脑能读数据流吗?水温多少 + [Round 1] + 问:95 + 答:上下水管温差怎么样啊?空气是不是都排干净了呢? + [Round 2] + 问:是的。上下水管都好的 + 答: + ``` + +- Label + + ``` + 那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况! + ``` + +要注意超过输入长度 `max_source_length` 的内容会被截。 + +可以参考以下指令: + +```shell +bash train_chat.sh +``` + +## 引用 + +``` +@inproceedings{liu2022p, + title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, + author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, + pages={61--68}, + year={2022} +} +``` + + + diff --git a/ptuning/README_en.md b/ptuning/README_en.md new file mode 100644 index 0000000..9282da3 --- /dev/null +++ b/ptuning/README_en.md @@ -0,0 +1,115 @@ +# ChatGLM-6B-PT +This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run. + +The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code. + +## Software dependencies +Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required +``` +pip install rouge_chinese nltk jieba datasets +``` +## Instructions + +### Download the dataset +The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content). + +```json +{ + "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", + "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" +} +``` + +From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) Download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory. + +### Training +Run the following commands for training: +```shell +bash train.sh +``` +`PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision. + +Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation. + +### Inference + +Change `CHECKPOINT` in `evaluate.sh` to the checkpoint name saved during training, and run the following commands for model inference and evaluation: +```shell +bash evaluate.sh +``` + +The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in +`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`. + +### Example +#### Example 1 +* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 +* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 +* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 +* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 + +#### Example 2 + +* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 +* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 +* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 +* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 + +### evaluation result + +| | P-tuning v2 | LoRA | +| ------- | ----------- | ----- | +| BLEU-4 | 7.71 | 6.13 | +| Rouge-1 | 31.35 | 28.36 | +| Rouge-2 | 7.19 | 4.38 | +| Rouge-l | 25.17 | 17.54 | + +#### Experiment Settings + + ``` +max_source_length=64 +max_target_length=64 +per_device_train_batch_size=1 +gradient_accumulation_steps=16 +max_steps=3000 + ``` + +##### P-tuning v2 + +``` +pre_seq_len=128 +learning_rate=2e-2 +quantization_bit=4 +``` + +##### LoRA + +``` +learning_rate=5e-4 +``` + +The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) + + + +## Model Deployment +Replace `THUDM/chatglm-6b` in the corresponding demo or code with the path of the checkpoint after P-Tuning(in the example, `./output/adgen-chatglm-6b-pt-8-1e-2/ checkpoint-3000`). Note that the current fine-tuning does not support multiple rounds of data, so only the responses from the first round of the conversation are fine-tuned. + +## Use your own dataset +Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text. + +## TODO +* [ ] Support for chat data +* [ ] Support for full finetuning + +## quoting + +``` +@inproceedings{liu2022p, + title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, + author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, + pages={61--68}, + year={2022} +} +``` \ No newline at end of file diff --git a/ptuning/arguments.py b/ptuning/arguments.py new file mode 100644 index 0000000..fda1f35 --- /dev/null +++ b/ptuning/arguments.py @@ -0,0 +1,224 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + ptuning_checkpoint: str = field( + default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": ( + "Will use the token generated when running `huggingface-cli login` (necessary to use this script " + "with private models)." + ) + }, + ) + resize_position_embeddings: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to automatically resize the position embeddings if `max_source_length` exceeds " + "the model's position embeddings." + ) + }, + ) + quantization_bit: Optional[int] = field( + default=None + ) + pre_seq_len: Optional[int] = field( + default=None + ) + prefix_projection: bool = field( + default=False + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + prompt_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + response_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, + ) + history_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the history of chat."}, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": ( + "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." + ) + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": ( + "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + ) + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": ( + "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + ) + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + ) + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + ) + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The token to force as the first generated token after the decoder_start_token_id." + "Useful for multilingual models like mBART where the first generated token" + "needs to be the target language token (Usually it is the target language token)" + ) + }, + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: + raise ValueError("Need either a dataset name or a training/validation/test file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + diff --git a/ptuning/deepspeed.json b/ptuning/deepspeed.json new file mode 100644 index 0000000..8e45509 --- /dev/null +++ b/ptuning/deepspeed.json @@ -0,0 +1,21 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": false, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients" : true + } +} \ No newline at end of file diff --git a/ptuning/ds_train_finetune.sh b/ptuning/ds_train_finetune.sh new file mode 100644 index 0000000..531a800 --- /dev/null +++ b/ptuning/ds_train_finetune.sh @@ -0,0 +1,28 @@ + +LR=1e-4 + +MASTER_PORT=$(shuf -n 1 -i 10000-65535) + +deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \ + --deepspeed deepspeed.json \ + --do_train \ + --train_file AdvertiseGen/train.json \ + --test_file AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir ./output/adgen-chatglm-6b-ft-$LR \ + --overwrite_output_dir \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --predict_with_generate \ + --max_steps 5000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --fp16 + diff --git a/ptuning/evaluate.sh b/ptuning/evaluate.sh new file mode 100644 index 0000000..ab85536 --- /dev/null +++ b/ptuning/evaluate.sh @@ -0,0 +1,21 @@ +PRE_SEQ_LEN=128 +CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2 +STEP=3000 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_predict \ + --validation_file AdvertiseGen/dev.json \ + --test_file AdvertiseGen/dev.json \ + --overwrite_cache \ + --prompt_column content \ + --response_column summary \ + --model_name_or_path THUDM/chatglm-6b \ + --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \ + --output_dir ./output/$CHECKPOINT \ + --overwrite_output_dir \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_eval_batch_size 1 \ + --predict_with_generate \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 diff --git a/ptuning/evaluate_finetune.sh b/ptuning/evaluate_finetune.sh new file mode 100644 index 0000000..e275c3c --- /dev/null +++ b/ptuning/evaluate_finetune.sh @@ -0,0 +1,18 @@ +CHECKPOINT=adgen-chatglm-6b-ft-1e-4 +STEP=3000 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_predict \ + --validation_file AdvertiseGen/dev.json \ + --test_file AdvertiseGen/dev.json \ + --overwrite_cache \ + --prompt_column content \ + --response_column summary \ + --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \ + --output_dir ./output/$CHECKPOINT \ + --overwrite_output_dir \ + --max_source_length 256 \ + --max_target_length 256 \ + --per_device_eval_batch_size 1 \ + --predict_with_generate \ + --fp16_full_eval diff --git a/ptuning/main.py b/ptuning/main.py new file mode 100644 index 0000000..43ecdf8 --- /dev/null +++ b/ptuning/main.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import os +import sys +import json + +import numpy as np +from datasets import load_dataset +import jieba +from rouge_chinese import Rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +import torch + +import transformers +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + Seq2SeqTrainingArguments, + set_seed, +) +from trainer_seq2seq import Seq2SeqTrainer + +from arguments import ModelArguments, DataTrainingArguments + +logger = logging.getLogger(__name__) + +def main(): + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + # datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Load dataset + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + + raw_datasets = load_dataset( + extension, + data_files=data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Load pretrained model and tokenizer + config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + config.pre_seq_len = model_args.pre_seq_len + config.prefix_projection = model_args.prefix_projection + + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + + if model_args.ptuning_checkpoint is not None: + # Evaluation + # Loading extra state dict of prefix encoder + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) + new_prefix_state_dict = {} + for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v + model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + else: + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + + if model_args.quantization_bit is not None: + print(f"Quantized to {model_args.quantization_bit} bit") + model = model.quantize(model_args.quantization_bit) + if model_args.pre_seq_len is not None: + # P-tuning v2 + model = model.half() + model.transformer.prefix_encoder.float() + else: + # Finetune + model = model.float() + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + column_names = raw_datasets["train"].column_names + elif training_args.do_eval: + column_names = raw_datasets["validation"].column_names + elif training_args.do_predict: + column_names = raw_datasets["test"].column_names + else: + logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + return + + # Get the column names for input/target. + prompt_column = data_args.prompt_column + response_column = data_args.response_column + history_column = data_args.history_column + + # Temporarily set max_target_length for training. + max_target_length = data_args.max_target_length + + def preprocess_function_eval(examples): + inputs, targets = [], [] + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query = examples[prompt_column][i] + if history_column is None or len(examples[history_column][i]) == 0: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs.append(prompt) + targets.append(examples[response_column][i]) + + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) + labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) + + if data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def preprocess_function_train(examples): + max_seq_length = data_args.max_source_length + data_args.max_target_length + + model_inputs = { + "input_ids": [], + "labels": [], + } + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query, answer = examples[prompt_column][i], examples[response_column][i] + + if history_column is None: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False) + + if len(a_ids) > data_args.max_source_length - 1: + a_ids = a_ids[: data_args.max_source_length - 1] + + if len(b_ids) > data_args.max_target_length - 2: + b_ids = b_ids[: data_args.max_target_length - 2] + + input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) + + context_length = input_ids.index(tokenizer.bos_token_id) + mask_position = context_length - 1 + labels = [-100] * context_length + input_ids[mask_position+1:] + + pad_len = max_seq_length - len(input_ids) + input_ids = input_ids + [tokenizer.pad_token_id] * pad_len + labels = labels + [tokenizer.pad_token_id] * pad_len + if data_args.ignore_pad_token_for_loss: + labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] + + model_inputs["input_ids"].append(input_ids) + model_inputs["labels"].append(labels) + + return model_inputs + + def print_dataset_example(example): + print("input_ids",example["input_ids"]) + print("inputs", tokenizer.decode(example["input_ids"])) + print("label_ids", example["labels"]) + print("labels", tokenizer.decode(example["labels"])) + + if training_args.do_train: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + with training_args.main_process_first(desc="train dataset map pre-processing"): + train_dataset = train_dataset.map( + preprocess_function_train, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) + print_dataset_example(train_dataset[0]) + + if training_args.do_eval: + max_target_length = data_args.val_max_target_length + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + with training_args.main_process_first(desc="validation dataset map pre-processing"): + eval_dataset = eval_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) + print_dataset_example(eval_dataset[0]) + + if training_args.do_predict: + max_target_length = data_args.val_max_target_length + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = raw_datasets["test"] + if data_args.max_predict_samples is not None: + max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) + predict_dataset = predict_dataset.select(range(max_predict_samples)) + with training_args.main_process_first(desc="prediction dataset map pre-processing"): + predict_dataset = predict_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on prediction dataset", + ) + print_dataset_example(predict_dataset[0]) + + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=None, + padding=False + ) + + # Metric + def compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + if data_args.ignore_pad_token_for_loss: + # Replace -100 in the labels as we can't decode them. + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + score_dict = { + "rouge-1": [], + "rouge-2": [], + "rouge-l": [], + "bleu-4": [] + } + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + rouge = Rouge() + scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + for k, v in score_dict.items(): + score_dict[k] = float(np.mean(v)) + return score_dict + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + training_args.generation_num_beams = ( + data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + ) + # Initialize our Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + save_prefixencoder=model_args.pre_seq_len is not None + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + # elif last_checkpoint is not None: + # checkpoint = last_checkpoint + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + train_result = trainer.train(resume_from_checkpoint=checkpoint) + # trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95) + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Predict ***") + + predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95) + metrics = predict_results.metrics + max_predict_samples = ( + data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) + ) + metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + + if trainer.is_world_process_zero(): + if training_args.predict_with_generate: + predictions = tokenizer.batch_decode( + predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + predictions = [pred.strip() for pred in predictions] + labels = tokenizer.batch_decode( + predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + labels = [label.strip() for label in labels] + output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") + with open(output_prediction_file, "w", encoding="utf-8") as writer: + for p, l in zip(predictions, labels): + res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) + writer.write(f"{res}\n") + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/ptuning/train.sh b/ptuning/train.sh new file mode 100644 index 0000000..efc9a16 --- /dev/null +++ b/ptuning/train.sh @@ -0,0 +1,26 @@ +PRE_SEQ_LEN=128 +LR=2e-2 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_train \ + --train_file AdvertiseGen/train.json \ + --validation_file AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ + --overwrite_output_dir \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --predict_with_generate \ + --max_steps 3000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 + diff --git a/ptuning/train_chat.sh b/ptuning/train_chat.sh new file mode 100644 index 0000000..b0f5cdc --- /dev/null +++ b/ptuning/train_chat.sh @@ -0,0 +1,27 @@ +PRE_SEQ_LEN=8 +LR=1e-2 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_train \ + --train_file $CHAT_TRAIN_DATA \ + --validation_file $CHAT_VAL_DATA \ + --prompt_column prompt \ + --response_column response \ + --history_column history \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir $CHECKPOINT_NAME \ + --overwrite_output_dir \ + --max_source_length 256 \ + --max_target_length 256 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --predict_with_generate \ + --max_steps 3000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 + diff --git a/ptuning/trainer.py b/ptuning/trainer.py new file mode 100644 index 0000000..63101bc --- /dev/null +++ b/ptuning/trainer.py @@ -0,0 +1,3830 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. +""" + +import contextlib +import functools +import glob +import inspect +import math +import os +import random +import re +import shutil +import sys +import time +import warnings +from collections.abc import Mapping +from distutils.util import strtobool +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +from tqdm.auto import tqdm + + +# Integrations must be imported before ML frameworks: +# isort: off +from transformers.integrations import ( + default_hp_search_backend, + get_reporting_integration_callbacks, + hp_params, + is_fairscale_available, + is_optuna_available, + is_ray_tune_available, + is_sigopt_available, + is_wandb_available, + run_hp_search_optuna, + run_hp_search_ray, + run_hp_search_sigopt, + run_hp_search_wandb, +) + +# isort: on + +import numpy as np +import torch +import torch.distributed as dist +from huggingface_hub import Repository, create_repo +from packaging import version +from torch import nn +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + +from transformers import __version__ +from transformers.configuration_utils import PretrainedConfig +from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from transformers.debug_utils import DebugOption, DebugUnderflowOverflow +from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled +from transformers.dependency_versions_check import dep_version_check +from transformers.modelcard import TrainingSummary +from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES +from transformers.optimization import Adafactor, get_scheduler +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, + DistributedTensorGatherer, + IterableDatasetShard, + LabelSmoother, + LengthGroupedSampler, + SequentialDistributedSampler, + ShardSampler, + distributed_broadcast_scalars, + distributed_concat, + find_batch_size, + get_module_class_from_name, + get_parameter_names, + nested_concat, + nested_detach, + nested_numpify, + nested_truncate, + nested_xla_mesh_reduce, + reissue_pt_warnings, +) +from transformers.trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalLoopOutput, + EvalPrediction, + FSDPOption, + HPSearchBackend, + HubStrategy, + IntervalStrategy, + PredictionOutput, + RemoveColumnsCollator, + ShardedDDPOption, + TrainerMemoryTracker, + TrainOutput, + default_compute_objective, + default_hp_space, + denumpify_detensorize, + enable_full_determinism, + find_executable_batch_size, + get_last_checkpoint, + has_length, + number_of_arguments, + seed_worker, + set_seed, + speed_metrics, +) +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments +from transformers.utils import ( + CONFIG_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + can_return_loss, + find_labels, + get_full_repo_name, + is_accelerate_available, + is_apex_available, + is_datasets_available, + is_in_notebook, + is_ipex_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_torch_compile_available, + is_torch_neuroncore_available, + is_torch_tpu_available, + logging, +) +from transformers.utils.generic import ContextManagers + + +_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10 + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + +if is_in_notebook(): + from transformers.utils.notebook import NotebookProgressCallback + + DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback + +if is_apex_available(): + from apex import amp + +if is_datasets_available(): + import datasets + +if is_torch_tpu_available(check_device=False): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + import torch_xla.distributed.parallel_loader as pl + +if is_fairscale_available(): + dep_version_check("fairscale") + import fairscale + from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP + from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP + from fairscale.nn.wrap import auto_wrap + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +skip_first_batches = None +if is_accelerate_available(): + from accelerate import __version__ as accelerate_version + + if version.parse(accelerate_version) >= version.parse("0.16"): + from accelerate import skip_first_batches + + +if TYPE_CHECKING: + import optuna + +logger = logging.get_logger(__name__) + + +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +SCHEDULER_NAME = "scheduler.pt" +SCALER_NAME = "scaler.pt" + + +class Trainer: + """ + Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. + + Args: + model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): + The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. + + + + [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use + your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers + models. + + + + args ([`TrainingArguments`], *optional*): + The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the + `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. + data_collator (`DataCollator`, *optional*): + The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will + default to [`default_data_collator`] if no `tokenizer` is provided, an instance of + [`DataCollatorWithPadding`] otherwise. + train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): + The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. + + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): + The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each + dataset prepending the dictionary key to the metric name. + tokenizer ([`PreTrainedTokenizerBase`], *optional*): + The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the + maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an + interrupted training or reuse the fine-tuned model. + model_init (`Callable[[], PreTrainedModel]`, *optional*): + A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start + from a new instance of the model as given by this function. + + The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to + be able to choose different architectures according to hyper parameters (such as layer count, sizes of + inner layers, dropout probabilities etc). + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return + a dictionary string to metric values. + callbacks (List of [`TrainerCallback`], *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](callback). + + If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple + containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model + and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + + Important attributes: + + - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] + subclass. + - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the + original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, + the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner + model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. + - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from + data parallelism, this means some of the model layers are split on different GPUs). + - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set + to `False` if model parallel or deepspeed is used, or if the default + `TrainingArguments.place_model_on_device` is overridden to return `False` . + - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while + in `train`) + + """ + + from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + save_prefixencoder: bool = False, + ): + self.save_prefixencoder = save_prefixencoder + if args is None: + output_dir = "tmp_trainer" + logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") + args = TrainingArguments(output_dir=output_dir) + self.args = args + # Seed must be set before instantiating the model when using model + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.hp_name = None + self.deepspeed = None + self.is_in_train = False + + # memory metrics - must set up as early as possible + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + + # set the correct log level depending on the node + log_level = args.get_process_log_level() + logging.set_verbosity(log_level) + + # force device and distributed setup init explicitly + args._setup_devices + + if model is None: + if model_init is not None: + self.model_init = model_init + model = self.call_model_init() + else: + raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") + else: + if model_init is not None: + warnings.warn( + "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" + " overwrite your model when calling the `train` method. This will become a fatal error in the next" + " release.", + FutureWarning, + ) + self.model_init = model_init + + if model.__class__.__name__ in MODEL_MAPPING_NAMES: + raise ValueError( + f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " + "computes hidden states and does not accept any labels. You should choose a model with a head " + "suitable for your task like any of the `AutoModelForXxx` listed at " + "https://huggingface.co/docs/transformers/model_doc/auto." + ) + + if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: + self.is_model_parallel = True + else: + self.is_model_parallel = False + + # At this stage the model is already loaded + if getattr(model, "is_loaded_in_8bit", False): + if getattr(model, "_is_int8_training_enabled", False): + logger.info( + "The model is loaded in 8-bit precision. To train this model you need to add additional modules" + " inside the model such as adapters using `peft` library and freeze the model weights. Please" + " check " + " the examples in https://github.com/huggingface/peft for more details." + ) + else: + raise ValueError( + "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit" + " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " + ) + + # Setup Sharded DDP training + self.sharded_ddp = None + if len(args.sharded_ddp) > 0: + if args.deepspeed: + raise ValueError( + "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if len(args.fsdp) > 0: + raise ValueError( + "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." + ) + + if args.local_rank == -1: + raise ValueError("Using sharded DDP only works in distributed training.") + elif not is_fairscale_available(): + raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") + elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: + raise ImportError( + "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " + f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." + ) + elif ShardedDDPOption.SIMPLE in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.SIMPLE + elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 + elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 + + self.fsdp = None + if len(args.fsdp) > 0: + if args.deepspeed: + raise ValueError( + "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if not args.fsdp_config["xla"] and args.local_rank == -1: + raise ValueError("Using fsdp only works in distributed training.") + + # dep_version_check("torch>=1.12.0") + # Would have to update setup.py with torch>=1.12.0 + # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0 + # below is the current alternative. + if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): + raise ValueError("FSDP requires PyTorch >= 1.12.0") + + from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy + + if FSDPOption.FULL_SHARD in args.fsdp: + self.fsdp = ShardingStrategy.FULL_SHARD + elif FSDPOption.SHARD_GRAD_OP in args.fsdp: + self.fsdp = ShardingStrategy.SHARD_GRAD_OP + elif FSDPOption.NO_SHARD in args.fsdp: + self.fsdp = ShardingStrategy.NO_SHARD + + self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE + if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch: + self.backward_prefetch = BackwardPrefetch.BACKWARD_POST + + self.forword_prefetch = False + if self.args.fsdp_config.get("forword_prefect", False): + self.forword_prefetch = True + + self.limit_all_gathers = False + if self.args.fsdp_config.get("limit_all_gathers", False): + self.limit_all_gathers = True + + # one place to sort out whether to place the model on device or not + # postpone switching model to cuda when: + # 1. MP - since we are trying to fit a much bigger than 1 gpu model + # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, + # and we only use deepspeed for training at the moment + # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first + # 4. Sharded DDP - same as MP + # 5. FSDP - same as MP + self.place_model_on_device = args.place_model_on_device + if ( + self.is_model_parallel + or args.deepspeed + or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) + or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) + or (self.fsdp is not None) + ): + self.place_model_on_device = False + + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + self.data_collator = data_collator if data_collator is not None else default_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + + if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False): + self._move_model_to_device(model, args.device) + + # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs + if self.is_model_parallel: + self.args._n_gpu = 1 + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not + self.model_wrapped = model + self.model = model + + self.compute_metrics = compute_metrics + self.preprocess_logits_for_metrics = preprocess_logits_for_metrics + self.optimizer, self.lr_scheduler = optimizers + if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): + raise RuntimeError( + "Passing a `model_init` is incompatible with providing the `optimizers` argument. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + if is_torch_tpu_available() and self.optimizer is not None: + for param in self.model.parameters(): + model_device = param.device + break + for param_group in self.optimizer.param_groups: + if len(param_group["params"]) > 0: + optimizer_device = param_group["params"][0].device + break + if model_device != optimizer_device: + raise ValueError( + "The model and the optimizer parameters are not on the same device, which probably means you" + " created an optimizer around your model **before** putting on the device and passing it to the" + " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" + " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." + ) + if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and ( + self.optimizer is not None or self.lr_scheduler is not None + ): + raise RuntimeError( + "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + + # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. + self._loggers_initialized = False + + # Create clone of distant repo and output directory if needed + if self.args.push_to_hub: + self.init_git_repo(at_init=True) + # In case of pull, we need to make sure every process has the latest. + if is_torch_tpu_available(): + xm.rendezvous("init git repo") + elif args.local_rank != -1: + dist.barrier() + + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): + raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") + + if args.max_steps > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: + raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") + + if ( + train_dataset is not None + and isinstance(train_dataset, torch.utils.data.IterableDataset) + and args.group_by_length + ): + raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") + + self._signature_columns = None + + # Mixed precision setup + self.use_apex = False + self.use_cuda_amp = False + self.use_cpu_amp = False + + # Mixed precision setup for SageMaker Model Parallel + if is_sagemaker_mp_enabled(): + # BF16 + model parallelism in SageMaker: currently not supported, raise an error + if args.bf16: + raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") + + if IS_SAGEMAKER_MP_POST_1_10: + # When there's mismatch between SMP config and trainer argument, use SMP config as truth + if args.fp16 != smp.state.cfg.fp16: + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," + f"but FP16 provided in trainer argument is {args.fp16}," + f"setting to {smp.state.cfg.fp16}" + ) + args.fp16 = smp.state.cfg.fp16 + else: + # smp < 1.10 does not support fp16 in trainer. + if hasattr(smp.state.cfg, "fp16"): + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." + ) + + if args.fp16 or args.bf16: + if args.half_precision_backend == "auto": + if args.device == torch.device("cpu"): + if args.fp16: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") + elif _is_native_cpu_amp_available: + args.half_precision_backend = "cpu_amp" + else: + raise ValueError("Tried to use cpu amp but native cpu amp is not available") + else: + args.half_precision_backend = "cuda_amp" + + logger.info(f"Using {args.half_precision_backend} half precision backend") + + self.do_grad_scaling = False + if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()): + # deepspeed and SageMaker Model Parallel manage their own half precision + if args.half_precision_backend == "cuda_amp": + self.use_cuda_amp = True + self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + # bf16 does not need grad scaling + self.do_grad_scaling = self.amp_dtype == torch.float16 + if self.do_grad_scaling: + if self.sharded_ddp is not None: + self.scaler = ShardedGradScaler() + elif self.fsdp is not None: + from torch.distributed.fsdp.sharded_grad_scaler import ( + ShardedGradScaler as FSDPShardedGradScaler, + ) + + self.scaler = FSDPShardedGradScaler() + elif is_torch_tpu_available(): + from torch_xla.amp import GradScaler + + self.scaler = GradScaler() + else: + self.scaler = torch.cuda.amp.GradScaler() + elif args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = torch.bfloat16 + else: + if not is_apex_available(): + raise ImportError( + "Using FP16 with APEX but APEX is not installed, please refer to" + " https://www.github.com/nvidia/apex." + ) + self.use_apex = True + + # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. + if ( + is_sagemaker_mp_enabled() + and self.use_cuda_amp + and args.max_grad_norm is not None + and args.max_grad_norm > 0 + ): + raise ValueError( + "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " + "along 'max_grad_norm': 0 in your hyperparameters." + ) + + # Label smoothing + if self.args.label_smoothing_factor != 0: + self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) + else: + self.label_smoother = None + + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) + + self.control = TrainerControl() + # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then + # returned to 0 every time flos need to be logged + self.current_flos = 0 + self.hp_search_backend = None + self.use_tune_checkpoints = False + default_label_names = find_labels(self.model.__class__) + self.label_names = default_label_names if self.args.label_names is None else self.args.label_names + self.can_return_loss = can_return_loss(self.model.__class__) + self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) + + # Internal variables to keep track of the original batch size + self._train_batch_size = args.train_batch_size + + # very last + self._memory_tracker.stop_and_update_metrics() + + # torch.compile + if args.torch_compile and not is_torch_compile_available(): + raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") + + def add_callback(self, callback): + """ + Add a callback to the current list of [`~transformer.TrainerCallback`]. + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will instantiate a member of that class. + """ + self.callback_handler.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it. + + If the callback is not found, returns `None` (and no error is raised). + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will pop the first member of that class found in the list of callbacks. + + Returns: + [`~transformer.TrainerCallback`]: The callback removed, if found. + """ + return self.callback_handler.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`]. + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will remove the first member of that class found in the list of callbacks. + """ + self.callback_handler.remove_callback(callback) + + def _move_model_to_device(self, model, device): + model = model.to(device) + # Moving a model to an XLA device disconnects the tied weights, so we have to retie them. + if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): + model.tie_weights() + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # Labels may be named label or label_ids, the default data collator handles that. + self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) + + def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): + if not self.args.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if description is None else f"in the {description} set" + logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " + " you can safely ignore this message." + ) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def _get_collator_with_removed_columns( + self, data_collator: Callable, description: Optional[str] = None + ) -> Callable: + """Wrap the data collator in a callable removing unused columns.""" + if not self.args.remove_unused_columns: + return data_collator + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + remove_columns_collator = RemoveColumnsCollator( + data_collator=data_collator, + signature_columns=signature_columns, + logger=logger, + description=description, + model_name=self.model.__class__.__name__, + ) + return remove_columns_collator + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + generator = None + if self.args.world_size <= 1: + generator = torch.Generator() + # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with + # `args.seed`) if data_seed isn't provided. + # Further on in this method, we default to `args.seed` instead. + if self.args.data_seed is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.args.data_seed + generator.manual_seed(seed) + + seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed + + # Build the sampler. + if self.args.group_by_length: + if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): + lengths = ( + self.train_dataset[self.args.length_column_name] + if self.args.length_column_name in self.train_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None + if self.args.world_size <= 1: + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + lengths=lengths, + model_input_name=model_input_name, + generator=generator, + ) + else: + return DistributedLengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + lengths=lengths, + model_input_name=model_input_name, + seed=seed, + ) + + else: + if self.args.world_size <= 1: + return RandomSampler(self.train_dataset, generator=generator) + elif ( + self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] + and not self.args.dataloader_drop_last + ): + # Use a loop for TPUs when drop_last is False to have all batches have the same size. + return DistributedSamplerWithLoop( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=seed, + ) + else: + return DistributedSampler( + self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=seed, + ) + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + if isinstance(train_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + train_dataset = IterableDatasetShard( + train_dataset, + batch_size=self._train_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + + return DataLoader( + train_dataset, + batch_size=self._train_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + train_sampler = self._get_train_sampler() + + return DataLoader( + train_dataset, + batch_size=self._train_batch_size, + sampler=train_sampler, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + worker_init_fn=seed_worker, + ) + + def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: + # Deprecated code + if self.args.use_legacy_prediction_loop: + if is_torch_tpu_available(): + return SequentialDistributedSampler( + eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() + ) + elif is_sagemaker_mp_enabled(): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) + elif self.args.local_rank != -1: + return SequentialDistributedSampler(eval_dataset) + else: + return SequentialSampler(eval_dataset) + + if self.args.world_size <= 1: + return SequentialSampler(eval_dataset) + else: + return ShardSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + if isinstance(eval_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + eval_dataset = IterableDatasetShard( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + eval_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + eval_sampler = self._get_eval_sampler(eval_dataset) + + return DataLoader( + eval_dataset, + sampler=eval_sampler, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + """ + Returns the test [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + test_dataset (`torch.utils.data.Dataset`, *optional*): + The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. It must implement `__len__`. + """ + data_collator = self.data_collator + + if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): + test_dataset = self._remove_unused_columns(test_dataset, description="test") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="test") + + if isinstance(test_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + test_dataset = IterableDatasetShard( + test_dataset, + batch_size=self.args.eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + test_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + test_sampler = self._get_eval_sampler(test_dataset) + + # We use the same batch_size as for eval. + return DataLoader( + test_dataset, + sampler=test_sampler, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. + """ + self.create_optimizer() + if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: + # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer + optimizer = self.optimizer.optimizer + else: + optimizer = self.optimizer + self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer = OSS( + params=optimizer_grouped_parameters, + optim=optimizer_cls, + **optimizer_kwargs, + ) + else: + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + print(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + print(f"skipped: {skipped/2**20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer + + @staticmethod + def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: + """ + Returns the optimizer class and optimizer parameters based on the training arguments. + + Args: + args (`transformers.training_args.TrainingArguments`): + The training arguments for the training session. + + """ + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optim_args[key] = value + + optimizer_kwargs = {"lr": args.learning_rate} + + adam_kwargs = { + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + if args.optim == OptimizerNames.ADAFACTOR: + optimizer_cls = Adafactor + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim == OptimizerNames.ADAMW_HF: + from transformers.optimization import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: + from torch.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: + optimizer_kwargs.update({"fused": True}) + elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: + try: + from torch_xla.amp.syncfree import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") + elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: + try: + from apex.optimizers import FusedAdam + + optimizer_cls = FusedAdam + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") + elif args.optim == OptimizerNames.ADAMW_BNB: + try: + from bitsandbytes.optim import Adam8bit + + optimizer_cls = Adam8bit + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") + elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: + try: + from torchdistx.optimizers import AnyPrecisionAdamW + + optimizer_cls = AnyPrecisionAdamW + optimizer_kwargs.update(adam_kwargs) + + # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. + optimizer_kwargs.update( + { + "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), + "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), + "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), + "compensation_buffer_dtype": getattr( + torch, optim_args.get("compensation_buffer_dtype", "bfloat16") + ), + } + ) + except ImportError: + raise ValueError("Please install https://github.com/pytorch/torchdistx") + elif args.optim == OptimizerNames.SGD: + optimizer_cls = torch.optim.SGD + elif args.optim == OptimizerNames.ADAGRAD: + optimizer_cls = torch.optim.Adagrad + else: + raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") + return optimizer_cls, optimizer_kwargs + + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + """ + if self.lr_scheduler is None: + self.lr_scheduler = get_scheduler( + self.args.lr_scheduler_type, + optimizer=self.optimizer if optimizer is None else optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + return self.lr_scheduler + + def num_examples(self, dataloader: DataLoader) -> int: + """ + Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When + dataloader.dataset does not exist or has no length, estimates as best it can + """ + try: + dataset = dataloader.dataset + # Special case for IterableDatasetShard, we need to dig deeper + if isinstance(dataset, IterableDatasetShard): + return len(dataloader.dataset.dataset) + return len(dataloader.dataset) + except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader + return len(dataloader) * self.args.per_device_train_batch_size + + def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): + """HP search setup code""" + self._trial = trial + + if self.hp_search_backend is None or trial is None: + return + if self.hp_search_backend == HPSearchBackend.OPTUNA: + params = self.hp_space(trial) + elif self.hp_search_backend == HPSearchBackend.RAY: + params = trial + params.pop("wandb", None) + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} + elif self.hp_search_backend == HPSearchBackend.WANDB: + params = trial + + for key, value in params.items(): + if not hasattr(self.args, key): + logger.warning( + f"Trying to set {key} in the hyperparameter search but there is no corresponding field in" + " `TrainingArguments`." + ) + continue + old_attr = getattr(self.args, key, None) + # Casting value to the proper type + if old_attr is not None: + value = type(old_attr)(value) + setattr(self.args, key, value) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + logger.info(f"Trial: {trial.params}") + if self.hp_search_backend == HPSearchBackend.SIGOPT: + logger.info(f"SigOpt Assignments: {trial.assignments}") + if self.hp_search_backend == HPSearchBackend.WANDB: + logger.info(f"W&B Sweep parameters: {trial}") + if self.args.deepspeed: + # Rebuild the deepspeed config to reflect the updated training parameters + from transformers.deepspeed import HfTrainerDeepSpeedConfig + + self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) + self.args.hf_deepspeed_config.trainer_config_process(self.args) + + def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): + if self.hp_search_backend is None or trial is None: + return + self.objective = self.compute_objective(metrics.copy()) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + import optuna + + trial.report(self.objective, step) + if trial.should_prune(): + self.callback_handler.on_train_end(self.args, self.state, self.control) + raise optuna.TrialPruned() + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + if self.control.should_save: + self._tune_save_checkpoint() + tune.report(objective=self.objective, **metrics) + + def _tune_save_checkpoint(self): + from ray import tune + + if not self.use_tune_checkpoints: + return + with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: + output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") + self.save_model(output_dir, _internal_call=True) + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + + def call_model_init(self, trial=None): + model_init_argcount = number_of_arguments(self.model_init) + if model_init_argcount == 0: + model = self.model_init() + elif model_init_argcount == 1: + model = self.model_init(trial) + else: + raise RuntimeError("model_init should have 0 or 1 argument.") + + if model is None: + raise RuntimeError("model_init should not return None.") + + return model + + def torch_jit_model_eval(self, model, dataloader, training=False): + if not training: + if dataloader is None: + logger.warning("failed to use PyTorch jit mode due to current dataloader is none.") + return model + example_batch = next(iter(dataloader)) + example_batch = self._prepare_inputs(example_batch) + try: + jit_model = model.eval() + with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]): + if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"): + if isinstance(example_batch, dict): + jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) + else: + jit_model = torch.jit.trace( + jit_model, + example_kwarg_inputs={key: example_batch[key] for key in example_batch}, + strict=False, + ) + else: + jit_inputs = [] + for key in example_batch: + example_tensor = torch.ones_like(example_batch[key]) + jit_inputs.append(example_tensor) + jit_inputs = tuple(jit_inputs) + jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False) + jit_model = torch.jit.freeze(jit_model) + with torch.no_grad(): + jit_model(**example_batch) + jit_model(**example_batch) + model = jit_model + self.use_cpu_amp = False + self.use_cuda_amp = False + except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e: + logger.warning(f"failed to use PyTorch jit mode due to: {e}.") + + return model + + def ipex_optimize_model(self, model, training=False, dtype=torch.float32): + if not is_ipex_available(): + raise ImportError( + "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer" + " to https://github.com/intel/intel-extension-for-pytorch." + ) + + import intel_extension_for_pytorch as ipex + + if not training: + model.eval() + dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype + # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings + model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train) + else: + if not model.training: + model.train() + model, self.optimizer = ipex.optimize( + model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1" + ) + + return model + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.torch_compile: + model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) + + if self.args.use_ipex: + dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 + model = self.ipex_optimize_model(model, training, dtype=dtype) + + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. + if isinstance(self.model_wrapped, smp.model.DistributedModel): + return self.model_wrapped + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + + # already initialized its own DDP and AMP + if self.deepspeed: + return self.deepspeed + + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if unwrap_model(model) is not model: + return model + + # Mixed precision training with apex (torch < 1.6) + if self.use_apex and training: + model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) + + # Multi-gpu training (should be after apex fp16 initialization) + if self.args.n_gpu > 1: + model = nn.DataParallel(model) + + if self.args.jit_mode_eval: + start_time = time.time() + model = self.torch_jit_model_eval(model, dataloader, training) + self.jit_compilation_time = round(time.time() - start_time, 4) + + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + # Distributed training (should be after apex fp16 initialization) + if self.sharded_ddp is not None: + # Sharded DDP! + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + model = ShardedDDP(model, self.optimizer) + else: + mixed_precision = self.args.fp16 or self.args.bf16 + cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp + zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 + # XXX: Breaking the self.model convention but I see no way around it for now. + if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: + model = auto_wrap(model) + self.model = model = FullyShardedDDP( + model, + mixed_precision=mixed_precision, + reshard_after_forward=zero_3, + cpu_offload=cpu_offload, + ).to(self.args.device) + # Distributed training using PyTorch FSDP + elif self.fsdp is not None: + if not self.args.fsdp_config["xla"]: + # PyTorch FSDP! + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy + + if FSDPOption.OFFLOAD in self.args.fsdp: + cpu_offload = CPUOffload(offload_params=True) + else: + cpu_offload = CPUOffload(offload_params=False) + + auto_wrap_policy = None + + if FSDPOption.AUTO_WRAP in self.args.fsdp: + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + mixed_precision_policy = None + dtype = None + if self.args.fp16: + dtype = torch.float16 + elif self.args.bf16: + dtype = torch.bfloat16 + if dtype is not None: + mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) + if type(model) != FSDP: + # XXX: Breaking the self.model convention but I see no way around it for now. + self.model = model = FSDP( + model, + sharding_strategy=self.fsdp, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + mixed_precision=mixed_precision_policy, + device_id=self.args.device, + backward_prefetch=self.backward_prefetch, + forward_prefetch=self.forword_prefetch, + limit_all_gathers=self.limit_all_gathers, + ) + else: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + return FSDP(checkpoint_module(m), *args, **kwargs) + + # Wrap the base model with an outer FSDP wrapper + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) + + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. + def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): + loss = optimizer.step(**optimizer_args) + if barrier: + xm.mark_step() + return loss + + xm.optimizer_step = patched_optimizer_step + elif is_sagemaker_dp_enabled(): + model = nn.parallel.DistributedDataParallel( + model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] + ) + elif self.args.local_rank != -1: + kwargs = {} + if self.args.ddp_find_unused_parameters is not None: + kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters + elif isinstance(model, PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing + else: + kwargs["find_unused_parameters"] = True + + if self.args.ddp_bucket_cap_mb is not None: + kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + if is_torch_neuroncore_available(): + return model + model = nn.parallel.DistributedDataParallel( + model, + device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None, + output_device=self.args.local_rank if self.args._n_gpu != 0 else None, + **kwargs, + ) + + return model + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", Dict[str, Any]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + **kwargs, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. + ignore_keys_for_eval (`List[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + kwargs: + Additional keyword arguments used to hide deprecated arguments + """ + if resume_from_checkpoint is False: + resume_from_checkpoint = None + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + args = self.args + + self.is_in_train = True + + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + self._move_model_to_device(self.model, args.device) + + if "model_path" in kwargs: + resume_from_checkpoint = kwargs.pop("model_path") + warnings.warn( + "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " + "instead.", + FutureWarning, + ) + if len(kwargs) > 0: + raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + # This might change the seed so needs to run first. + self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size + + # Model re-init + model_reloaded = False + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = self.call_model_init(trial) + model_reloaded = True + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") + + if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None: + self._load_from_checkpoint(resume_from_checkpoint) + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + inner_training_loop = find_executable_batch_size( + self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size + ) + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self._train_batch_size = batch_size + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. + num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torch.distributed.launch)." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = ( + self.sharded_ddp is not None + and self.sharded_ddp != ShardedDDPOption.SIMPLE + or is_sagemaker_mp_enabled() + or self.fsdp is not None + ) + if args.deepspeed: + deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( + self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + elif not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState() + self.state.is_hyper_param_search = trial is not None + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + model = self._wrap_model(self.model_wrapped) + + if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: + self._load_from_checkpoint(resume_from_checkpoint, model) + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") + logger.info( + f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}" + ) + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + if skip_first_batches is None: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," + " you can install the latest version of Accelerate with `pip install -U accelerate`.You can" + " also add the `--ignore_data_skip` flag to your launch command, but you will resume the" + " training on data already seen by your model." + ) + else: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) + if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: + steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) + steps_trained_progress_bar.set_description("Skipping the first batches") + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. + if not args.ignore_data_skip: + for epoch in range(epochs_trained): + is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( + train_dataloader.sampler, RandomSampler + ) + if is_torch_less_than_1_11 or not is_random_sampler: + # We just need to begin an iteration to create the randomization of the sampler. + # That was before PyTorch 1.11 however... + for _ in train_dataloader: + break + else: + # Otherwise we need to call the whooooole sampler cause there is some random operation added + # AT THE VERY END! + _ = list(train_dataloader.sampler) + + total_batched_samples = 0 + for epoch in range(epochs_trained, num_train_epochs): + if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch) + elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): + train_dataloader.dataset.set_epoch(epoch) + + if is_torch_tpu_available(): + parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) + epoch_iterator = parallel_loader + else: + epoch_iterator = train_dataloader + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if skip_first_batches is not None and steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + for step, inputs in enumerate(epoch_iterator): + total_batched_samples += 1 + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + if ( + (total_batched_samples % args.gradient_accumulation_steps != 0) + and args.local_rank != -1 + and args._no_sync_in_gradient_accumulation + ): + # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. + with model.no_sync(): + tr_loss_step = self.training_step(model, inputs) + else: + tr_loss_step = self.training_step(model, inputs) + + if ( + args.logging_nan_inf_filter + and not is_torch_tpu_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps + if self.deepspeed: + self.deepspeed.step() + + if total_batched_samples % args.gradient_accumulation_steps == 0 or ( + # last step in epoch but step is always smaller than gradient_accumulation_steps + steps_in_epoch <= args.gradient_accumulation_steps + and (step + 1) == steps_in_epoch + ): + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: + # deepspeed does its own clipping + + if self.do_grad_scaling: + # Reduce gradients first for XLA + if is_torch_tpu_available(): + gradients = xm._fetch_gradients(self.optimizer) + xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) + # AMP: gradients need unscaling + self.scaler.unscale_(self.optimizer) + + if is_sagemaker_mp_enabled() and args.fp16: + self.optimizer.clip_master_grads(args.max_grad_norm) + elif hasattr(self.optimizer, "clip_grad_norm"): + # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping + self.optimizer.clip_grad_norm(args.max_grad_norm) + elif hasattr(model, "clip_grad_norm_"): + # Some models (like FullyShardedDDP) have a specific way to do gradient clipping + model.clip_grad_norm_(args.max_grad_norm) + else: + # Revert to normal clipping otherwise, handling Apex or full precision + nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer) if self.use_apex else model.parameters(), + args.max_grad_norm, + ) + + # Optimizer step + optimizer_was_run = True + if self.deepspeed: + pass # called outside the loop + elif is_torch_tpu_available(): + if self.do_grad_scaling: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + xm.optimizer_step(self.optimizer) + elif self.do_grad_scaling: + scale_before = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler.get_scale() + optimizer_was_run = scale_before <= scale_after + else: + self.optimizer.step() + + if optimizer_was_run and not self.deepspeed: + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + if step < 0: + logger.warning( + "There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_tpu_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sur the model has been saved by process 0. + if is_torch_tpu_available(): + xm.rendezvous("load_best_model_at_end") + elif args.local_rank != -1: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + train_loss = self._total_loss_scalar / self.state.global_step + + metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if checkpoint != self.state.best_model_checkpoint: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def _get_output_dir(self, trial): + if self.hp_search_backend is not None and trial is not None: + if self.hp_search_backend == HPSearchBackend.OPTUNA: + run_id = trial.number + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + run_id = tune.get_trial_id() + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + run_id = trial.id + elif self.hp_search_backend == HPSearchBackend.WANDB: + import wandb + + run_id = wandb.run.id + run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" + run_dir = os.path.join(self.args.output_dir, run_name) + else: + run_dir = self.args.output_dir + return run_dir + + def _load_from_checkpoint(self, resume_from_checkpoint, model=None): + if model is None: + model = self.model + + if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile( + os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) + ): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + + logger.info(f"Loading model from {resume_from_checkpoint}.") + + if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): + config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) + checkpoint_version = config.transformers_version + if checkpoint_version is not None and checkpoint_version != __version__: + logger.warning( + f"You are resuming training from a checkpoint trained with {checkpoint_version} of " + f"Transformers but your current version is {__version__}. This is not recommended and could " + "yield to errors or unwanted behaviors." + ) + + if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): + # If the model is on the GPU, it still works! + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if hasattr(self.args, "fp16") and self.args.fp16 is True: + logger.warning( + "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." + ) + state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + # Required for smp to not auto-translate state_dict from hf to smp (is already smp). + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + # release memory + del state_dict + else: + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + # release memory + del state_dict + self._issue_warnings_after_load(load_result) + else: + # We load the sharded checkpoint + load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled()) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + + def _load_best_model(self): + logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if os.path.exists(best_model_path): + if self.deepspeed: + if self.model_wrapped is not None: + # this removes the pre-hooks from the previous engine + self.model_wrapped.destroy() + self.model_wrapped = None + + # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping + deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( + self, + num_training_steps=self.args.max_steps, + resume_from_checkpoint=self.state.best_model_checkpoint, + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + else: + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=self.state.best_model_checkpoint, + tag=WEIGHTS_NAME, + partial=False, + load_optimizer=False, + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + state_dict = torch.load(best_model_path, map_location="cpu") + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + else: + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(best_model_path, map_location="cpu") + # If the model is on the GPU, it still works! + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + load_result = load_sharded_checkpoint( + model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + def _issue_warnings_after_load(self, load_result): + if len(load_result.missing_keys) != 0: + if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( + self.model._keys_to_ignore_on_save + ): + self.model.tie_weights() + else: + logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") + if len(load_result.unexpected_keys) != 0: + logger.warning( + f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." + ) + + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): + if self.control.should_log: + if is_torch_tpu_available(): + xm.mark_step() + + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + if isinstance(self.eval_dataset, dict): + for eval_dataset_name, eval_dataset in self.eval_dataset.items(): + metrics = self.evaluate( + eval_dataset=eval_dataset, + ignore_keys=ignore_keys_for_eval, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + else: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def _load_rng_state(self, checkpoint): + # Load RNG states from `checkpoint` + if checkpoint is None: + return + + if self.args.world_size > 1: + process_index = self.args.process_index + rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") + if not os.path.isfile(rng_file): + logger.info( + f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " + "wasn't launched in a distributed fashion, reproducibility is not guaranteed." + ) + return + else: + rng_file = os.path.join(checkpoint, "rng_state.pth") + if not os.path.isfile(rng_file): + logger.info( + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " + "fashion, reproducibility is not guaranteed." + ) + return + + checkpoint_rng_state = torch.load(rng_file) + random.setstate(checkpoint_rng_state["python"]) + np.random.set_state(checkpoint_rng_state["numpy"]) + torch.random.set_rng_state(checkpoint_rng_state["cpu"]) + if torch.cuda.is_available(): + if self.args.local_rank != -1: + torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) + else: + try: + torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + if is_torch_tpu_available(): + xm.set_rng_state(checkpoint_rng_state["xla"]) + + def _save_checkpoint(self, model, trial, metrics=None): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + if self.deepspeed: + # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed + # config `stage3_gather_16bit_weights_on_model_save` is True + self.deepspeed.save_checkpoint(output_dir) + + # Save optimizer and scheduler + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer.consolidate_state_dict() + + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) + smp.barrier() + if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: + smp.save( + opt_state_dict, + os.path.join(output_dir, OPTIMIZER_NAME), + partial=True, + v3=smp.state.cfg.shard_optimizer_state, + ) + if self.args.should_save: + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + elif self.args.should_save and not self.deepspeed: + # deepspeed.save_checkpoint above saves model/optim/sched + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.local_rank == -1: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_tpu_available(): + rng_states["xla"] = xm.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. + os.makedirs(output_dir, exist_ok=True) + + if self.args.world_size <= 1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + def _load_optimizer_and_scheduler(self, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if self.deepspeed: + # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init + return + + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") + if is_sagemaker_mp_enabled() + else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) + ) + if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + # Load in optimizer and scheduler states + if is_torch_tpu_available(): + # On TPU we have to take some extra precautions to properly load the states on the right device. + optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") + with warnings.catch_warnings(record=True) as caught_warnings: + lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") + reissue_pt_warnings(caught_warnings) + + xm.send_cpu_data_to_device(optimizer_state, self.args.device) + xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) + + self.optimizer.load_state_dict(optimizer_state) + self.lr_scheduler.load_state_dict(lr_scheduler_state) + else: + map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): + # Optimizer checkpoint was saved with smp >= 1.10 + def opt_load_hook(mod, opt): + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + else: + # Optimizer checkpoint was saved with smp < 1.10 + def opt_load_hook(mod, opt): + if IS_SAGEMAKER_MP_POST_1_10: + opt.load_state_dict( + smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) + ) + else: + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + self.model_wrapped.register_post_step_hook(opt_load_hook) + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + ) + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): + self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) + + def hyperparameter_search( + self, + hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, + compute_objective: Optional[Callable[[Dict[str, float]], float]] = None, + n_trials: int = 20, + direction: str = "minimize", + backend: Optional[Union["str", HPSearchBackend]] = None, + hp_name: Optional[Callable[["optuna.Trial"], str]] = None, + **kwargs, + ) -> BestRun: + """ + Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined + by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided, + the sum of all metrics otherwise. + + + + To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to + reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to + subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom + optimizer/scheduler. + + + + Args: + hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*): + A function that defines the hyperparameter search space. Will default to + [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or + [`~trainer_utils.default_hp_space_sigopt`] depending on your backend. + compute_objective (`Callable[[Dict[str, float]], float]`, *optional*): + A function computing the objective to minimize or maximize from the metrics returned by the `evaluate` + method. Will default to [`~trainer_utils.default_compute_objective`]. + n_trials (`int`, *optional*, defaults to 100): + The number of trial runs to test. + direction (`str`, *optional*, defaults to `"minimize"`): + Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick + `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics. + backend (`str` or [`~training_utils.HPSearchBackend`], *optional*): + The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending + on which one is installed. If all are installed, will default to optuna. + hp_name (`Callable[["optuna.Trial"], str]]`, *optional*): + A function that defines the trial/run name. Will default to None. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more + information see: + + - the documentation of + [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) + - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run) + - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create) + + Returns: + [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in + `run_summary` attribute for Ray backend. + """ + if backend is None: + backend = default_hp_search_backend() + if backend is None: + raise RuntimeError( + "At least one of optuna or ray should be installed. " + "To install optuna run `pip install optuna`. " + "To install ray run `pip install ray[tune]`. " + "To install sigopt run `pip install sigopt`." + ) + backend = HPSearchBackend(backend) + if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): + raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") + if backend == HPSearchBackend.RAY and not is_ray_tune_available(): + raise RuntimeError( + "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." + ) + if backend == HPSearchBackend.SIGOPT and not is_sigopt_available(): + raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.") + if backend == HPSearchBackend.WANDB and not is_wandb_available(): + raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.") + self.hp_search_backend = backend + if self.model_init is None: + raise RuntimeError( + "To use hyperparameter search, you need to pass your model through a model_init function." + ) + + self.hp_space = default_hp_space[backend] if hp_space is None else hp_space + self.hp_name = hp_name + self.compute_objective = default_compute_objective if compute_objective is None else compute_objective + + backend_dict = { + HPSearchBackend.OPTUNA: run_hp_search_optuna, + HPSearchBackend.RAY: run_hp_search_ray, + HPSearchBackend.SIGOPT: run_hp_search_sigopt, + HPSearchBackend.WANDB: run_hp_search_wandb, + } + best_run = backend_dict[backend](self, n_trials, direction, **kwargs) + + self.hp_search_backend = None + return best_run + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + if self.state.epoch is not None: + logs["epoch"] = round(self.state.epoch, 2) + + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) + self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + + def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + """ + if isinstance(data, Mapping): + return type(data)({k: self._prepare_input(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = {"device": self.args.device} + if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)): + # NLP models inputs are int/uint and those get adjusted to the right dtype of the + # embedding. Other models such as wav2vec2's inputs are already float and thus + # may need special handling to match the dtypes of the model + kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()}) + return data.to(**kwargs) + return data + + def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and + handling potential state. + """ + inputs = self._prepare_input(inputs) + if len(inputs) == 0: + raise ValueError( + "The batch received was empty, your model won't be able to train on it. Double-check that your " + f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." + ) + if self.args.past_index >= 0 and self._past is not None: + inputs["mems"] = self._past + + return inputs + + def compute_loss_context_manager(self): + """ + A helper wrapper to group together context managers. + """ + return self.autocast_smart_context_manager() + + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. + """ + if self.use_cuda_amp or self.use_cpu_amp: + if is_torch_greater_or_equal_than_1_10: + ctx_manager = ( + torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + if self.use_cpu_amp + else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + ) + else: + ctx_manager = torch.cuda.amp.autocast() + else: + ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() + + return ctx_manager + + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + inputs = self._prepare_inputs(inputs) + + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: + # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` + loss = loss / self.args.gradient_accumulation_steps + + if self.do_grad_scaling: + self.scaler.scale(loss).backward() + elif self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + elif self.deepspeed: + # loss gets scaled under gradient_accumulation_steps in deepspeed + loss = self.deepspeed.backward(loss) + else: + loss.backward() + + return loss.detach() + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss + + def is_local_process_zero(self) -> bool: + """ + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several + machines) main process. + """ + return self.args.local_process_index == 0 + + def is_world_process_zero(self) -> bool: + """ + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + """ + # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global + # process index. + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.args.process_index == 0 + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + """ + Will save the model, so you can reload it using `from_pretrained()`. + + Will only save from the main process. + """ + + if output_dir is None: + output_dir = self.args.output_dir + + if is_torch_tpu_available(): + self._save_tpu(output_dir) + elif is_sagemaker_mp_enabled(): + # Calling the state_dict needs to be done on the wrapped model and on all processes. + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model_wrapped.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + if IS_SAGEMAKER_MP_POST_1_10: + # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 + Path(os.path.join(output_dir, "user_content.pt")).touch() + elif ( + ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp + or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp + or self.fsdp is not None + ): + state_dict = self.model.state_dict() + + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + elif self.deepspeed: + # this takes care of everything as long as we aren't under zero3 + if self.args.should_save: + self._save(output_dir) + + if is_deepspeed_zero3_enabled(): + # It's too complicated to try to override different places where the weights dump gets + # saved, so since under zero3 the file is bogus, simply delete it. The user should + # either user deepspeed checkpoint to resume or to recover full weights use + # zero_to_fp32.py stored in the checkpoint. + if self.args.should_save: + file = os.path.join(output_dir, WEIGHTS_NAME) + if os.path.isfile(file): + # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights") + os.remove(file) + + # now save the real model if stage3_gather_16bit_weights_on_model_save=True + # if false it will not be saved. + # This must be called on all ranks + if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME): + logger.warning( + "deepspeed.save_16bit_model didn't save the model, since" + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + self.deepspeed.save_checkpoint(output_dir) + + elif self.args.should_save: + self._save(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + + def _save_tpu(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + logger.info(f"Saving model checkpoint to {output_dir}") + + if xm.is_master_ordinal(): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + xm.rendezvous("saving_checkpoint") + if not isinstance(self.model, PreTrainedModel): + if isinstance(unwrap_model(self.model), PreTrainedModel): + unwrap_model(self.model).save_pretrained( + output_dir, + is_main_process=self.args.should_save, + state_dict=self.model.state_dict(), + save_function=xm.save, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = self.model.state_dict() + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + if self.tokenizer is not None and self.args.should_save: + self.tokenizer.save_pretrained(output_dir) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, PreTrainedModel): + if isinstance(unwrap_model(self.model), PreTrainedModel): + if state_dict is None: + state_dict = self.model.state_dict() + unwrap_model(self.model).save_pretrained(output_dir, state_dict=filtered_state_dict) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if state_dict is None: + state_dict = self.model.state_dict() + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + if self.save_prefixencoder: + print("Saving PrefixEncoder") + state_dict = self.model.state_dict() + filtered_state_dict = {} + for k, v in self.model.named_parameters(): + if v.requires_grad: + filtered_state_dict[k] = state_dict[k] + self.model.save_pretrained(output_dir, state_dict=filtered_state_dict) + else: + print("Saving the whole model") + self.model.save_pretrained(output_dir, state_dict=state_dict) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + def store_flos(self): + # Storing the number of floating-point operations that went into the model + if self.args.local_rank != -1: + self.state.total_flos += ( + distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() + ) + self.current_flos = 0 + else: + self.state.total_flos += self.current_flos + self.current_flos = 0 + + def _sorted_checkpoints( + self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False + ) -> List[str]: + ordering_and_checkpoint_path = [] + + glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)] + + for path in glob_checkpoints: + if use_mtime: + ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) + else: + regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) + if regex_match is not None and regex_match.groups() is not None: + ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + # Make sure we don't delete the best model. + if self.state.best_model_checkpoint is not None: + best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) + for i in range(best_model_index, len(checkpoints_sorted) - 2): + checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] + return checkpoints_sorted + + def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: + if self.args.save_total_limit is None or self.args.save_total_limit <= 0: + return + + # Check if we should delete older checkpoint(s) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir) + if len(checkpoints_sorted) <= self.args.save_total_limit: + return + + # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which + # we don't do to allow resuming. + save_total_limit = self.args.save_total_limit + if ( + self.state.best_model_checkpoint is not None + and self.args.save_total_limit == 1 + and checkpoints_sorted[-1] != self.state.best_model_checkpoint + ): + save_total_limit = 2 + + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is "eval" (default) + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + + def predict( + self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"test"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "test_bleu" if the prefix is "test" (default) + + + + If your predictions or labels have different sequence length (for instance because you're doing dynamic padding + in a token classification task) the predictions will be padded (on the right) to allow for concatenation into + one array. The padding index is -100. + + + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + test_dataloader = self.get_test_dataloader(test_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix + ) + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train init deepspeed here + if args.deepspeed and not self.deepspeed: + # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval + # from the checkpoint eventually + deepspeed_engine, _, _ = deepspeed_init( + self, num_training_steps=0, resume_from_checkpoint=None, inference=True + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = self.args.eval_batch_size + + logger.info(f"***** Running {description} *****") + if has_length(dataloader): + logger.info(f" Num examples = {self.num_examples(dataloader)}") + else: + logger.info(" Num examples: Unknown") + logger.info(f" Batch size = {batch_size}") + + model.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = getattr(dataloader, "dataset", None) + + if is_torch_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) + + if args.past_index >= 0: + self._past = None + + # Initialize containers + # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) + losses_host = None + preds_host = None + labels_host = None + inputs_host = None + + # losses/preds/labels on CPU (final containers) + all_losses = None + all_preds = None + all_labels = None + all_inputs = None + # Will be useful when we have an iterable dataset so don't know its length. + + observed_num_examples = 0 + # Main evaluation loop + for step, inputs in enumerate(dataloader): + # Update the observed num examples + observed_batch_size = find_batch_size(inputs) + if observed_batch_size is not None: + observed_num_examples += observed_batch_size + # For batch samplers, batch_size is not known by the dataloader in advance. + if batch_size is None: + batch_size = observed_batch_size + + # Prediction step + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None + + if is_torch_tpu_available(): + xm.mark_step() + + # Update containers on host + if loss is not None: + losses = self._nested_gather(loss.repeat(batch_size)) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if labels is not None: + labels = self._pad_across_processes(labels) + labels = self._nested_gather(labels) + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_decode = self._pad_across_processes(inputs_decode) + inputs_decode = self._nested_gather(inputs_decode) + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + if logits is not None: + logits = self._pad_across_processes(logits) + logits = self._nested_gather(logits) + if self.preprocess_logits_for_metrics is not None: + logits = self.preprocess_logits_for_metrics(logits, labels) + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode + if all_inputs is None + else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = ( + labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + ) + + # Set back to None to begin a new accumulation + losses_host, preds_host, inputs_host, labels_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + + # Number of samples + if has_length(eval_dataset): + num_samples = len(eval_dataset) + # The instance check is weird and does not actually check for the type, but whether the dataset has the right + # methods. Therefore we need to make sure it also has the attribute. + elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: + num_samples = eval_dataset.num_examples + else: + if has_length(dataloader): + num_samples = self.num_examples(dataloader) + else: # both len(dataloader.dataset) and len(dataloader) fail + num_samples = observed_num_examples + if num_samples == 0 and observed_num_examples > 0: + num_samples = observed_num_examples + + # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of + # samplers has been rounded to a multiple of batch_size, so we truncate. + if all_losses is not None: + all_losses = all_losses[:num_samples] + if all_preds is not None: + all_preds = nested_truncate(all_preds, num_samples) + if all_labels is not None: + all_labels = nested_truncate(all_labels, num_samples) + if all_inputs is not None: + all_inputs = nested_truncate(all_inputs, num_samples) + + # Metrics! + if self.compute_metrics is not None and all_preds is not None and all_labels is not None: + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) + else: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if all_losses is not None: + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + if hasattr(self, "jit_compilation_time"): + metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) + + def _nested_gather(self, tensors, name=None): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_tpu_available(): + if name is None: + name = "nested_gather" + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.local_rank != -1: + tensors = distributed_concat(tensors) + return tensors + + # Copied from Accelerate. + def _pad_across_processes(self, tensor, pad_index=-100): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so + they can safely be gathered. + """ + if isinstance(tensor, (list, tuple)): + return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) + elif not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." + ) + + if len(tensor.shape) < 2: + return tensor + # Gather all sizes + size = torch.tensor(tensor.shape, device=tensor.device)[None] + sizes = self._nested_gather(size).cpu() + + max_size = max(s[1] for s in sizes) + # When extracting XLA graphs for compilation, max_size is 0, + # so use inequality to avoid errors. + if tensor.shape[1] >= max_size: + return tensor + + # Then pad to the maximum size + old_size = tensor.shape + new_size = list(old_size) + new_size[1] = max_size + new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index + new_tensor[:, : old_size[1]] = tensor + return new_tensor + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + + def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): + """ + For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point + operations for every backward + forward pass. If using another model, either implement such a method in the + model or subclass and override this method. + + Args: + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + Returns: + `int`: The number of floating-point operations. + """ + if hasattr(self.model, "floating_point_ops"): + return self.model.floating_point_ops(inputs) + else: + return 0 + + def init_git_repo(self, at_init: bool = False): + """ + Initializes a git repo in `self.args.hub_model_id`. + + Args: + at_init (`bool`, *optional*, defaults to `False`): + Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is + `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped + out. + """ + if not self.is_world_process_zero(): + return + if self.args.hub_model_id is None: + repo_name = Path(self.args.output_dir).absolute().name + else: + repo_name = self.args.hub_model_id + if "/" not in repo_name: + repo_name = get_full_repo_name(repo_name, token=self.args.hub_token) + + # Make sure the repo exists. + create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) + try: + self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + except EnvironmentError: + if self.args.overwrite_output_dir and at_init: + # Try again after wiping output_dir + shutil.rmtree(self.args.output_dir) + self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + else: + raise + + self.repo.git_pull() + + # By default, ignore the checkpoint folders + if ( + not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")) + and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS + ): + with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: + writer.writelines(["checkpoint-*/"]) + + # Add "*.sagemaker" to .gitignore if using SageMaker + if os.environ.get("SM_TRAINING_ENV"): + self._add_sm_patterns_to_gitignore() + + self.push_in_progress = None + + def create_model_card( + self, + language: Optional[str] = None, + license: Optional[str] = None, + tags: Union[str, List[str], None] = None, + model_name: Optional[str] = None, + finetuned_from: Optional[str] = None, + tasks: Union[str, List[str], None] = None, + dataset_tags: Union[str, List[str], None] = None, + dataset: Union[str, List[str], None] = None, + dataset_args: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + language (`str`, *optional*): + The language of the model (if applicable) + license (`str`, *optional*): + The license of the model. Will default to the license of the pretrained model used, if the original + model given to the `Trainer` comes from a repo on the Hub. + tags (`str` or `List[str]`, *optional*): + Some tags to be included in the metadata of the model card. + model_name (`str`, *optional*): + The name of the model. + finetuned_from (`str`, *optional*): + The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo + of the original model given to the `Trainer` (if it comes from the Hub). + tasks (`str` or `List[str]`, *optional*): + One or several task identifiers, to be included in the metadata of the model card. + dataset_tags (`str` or `List[str]`, *optional*): + One or several dataset tags, to be included in the metadata of the model card. + dataset (`str` or `List[str]`, *optional*): + One or several dataset identifiers, to be included in the metadata of the model card. + dataset_args (`str` or `List[str]`, *optional*): + One or several dataset arguments, to be included in the metadata of the model card. + """ + if not self.is_world_process_zero(): + return + + training_summary = TrainingSummary.from_trainer( + self, + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + ) + model_card = training_summary.to_model_card() + with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: + f.write(model_card) + + def _push_from_checkpoint(self, checkpoint_folder): + # Only push from one node. + if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: + return + # If we haven't finished the last push, we don't do this one. + if self.push_in_progress is not None and not self.push_in_progress.is_done: + return + + output_dir = self.args.output_dir + # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder + modeling_files = [CONFIG_NAME, WEIGHTS_NAME] + for modeling_file in modeling_files: + if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): + shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) + # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + # Same for the training arguments + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + try: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Temporarily move the checkpoint just saved for the push + tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") + # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a + # subfolder. + if os.path.isdir(tmp_checkpoint): + shutil.rmtree(tmp_checkpoint) + shutil.move(checkpoint_folder, tmp_checkpoint) + + if self.args.save_strategy == IntervalStrategy.STEPS: + commit_message = f"Training in progress, step {self.state.global_step}" + else: + commit_message = f"Training in progress, epoch {int(self.state.epoch)}" + _, self.push_in_progress = self.repo.push_to_hub( + commit_message=commit_message, blocking=False, auto_lfs_prune=True + ) + finally: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Move back the checkpoint to its place + shutil.move(tmp_checkpoint, checkpoint_folder) + + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + + Parameters: + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + kwargs: + Additional keyword arguments passed along to [`~Trainer.create_model_card`]. + + Returns: + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of + the commit and an object to track the progress of the commit if `blocking=True` + """ + # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but + # it might fail. + if not hasattr(self, "repo"): + self.init_git_repo() + + model_name = kwargs.pop("model_name", None) + if model_name is None and self.args.should_save: + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + + # Needs to be executed on all processes for TPU training, but will only save on the processed determined by + # self.args.should_save. + self.save_model(_internal_call=True) + + # Only push from one node. + if not self.is_world_process_zero(): + return + + # Cancel any async push in progress if blocking=True. The commits will all be pushed together. + if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done: + self.push_in_progress._process.kill() + self.push_in_progress = None + + git_head_commit_url = self.repo.push_to_hub( + commit_message=commit_message, blocking=blocking, auto_lfs_prune=True + ) + # push separately the model card to be independant from the rest of the model + if self.args.should_save: + self.create_model_card(model_name=model_name, **kwargs) + try: + self.repo.push_to_hub( + commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True + ) + except EnvironmentError as exc: + logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") + + return git_head_commit_url + + # + # Deprecated code + # + + def prediction_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + if not has_length(dataloader): + raise ValueError("dataloader must implement a working __len__") + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train init deepspeed here + if args.deepspeed and not self.deepspeed: + # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval + # from the checkpoint eventually + deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since + # for example the Z3-optimizer is a must for zero3 to work even for inference - what we + # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer + deepspeed_engine.optimizer.optimizer = None + deepspeed_engine.lr_scheduler = None + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = dataloader.batch_size + num_examples = self.num_examples(dataloader) + logger.info(f"***** Running {description} *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Batch size = {batch_size}") + losses_host: torch.Tensor = None + preds_host: Union[torch.Tensor, List[torch.Tensor]] = None + labels_host: Union[torch.Tensor, List[torch.Tensor]] = None + inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None + + world_size = max(1, args.world_size) + + eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + if not prediction_loss_only: + # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass + # a batch size to the sampler) + make_multiple_of = None + if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): + make_multiple_of = dataloader.sampler.batch_size + preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + + model.eval() + + if is_torch_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) + + if args.past_index >= 0: + self._past = None + + self.callback_handler.eval_dataloader = dataloader + + for step, inputs in enumerate(dataloader): + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None + + if loss is not None: + losses = loss.repeat(batch_size) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if logits is not None: + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + if labels is not None: + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + # Set back to None to begin a new accumulation + losses_host, preds_host, labels_host, inputs_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + eval_loss = eval_losses_gatherer.finalize() + preds = preds_gatherer.finalize() if not prediction_loss_only else None + label_ids = labels_gatherer.finalize() if not prediction_loss_only else None + inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None + + if self.compute_metrics is not None and preds is not None and label_ids is not None: + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) + else: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if eval_loss is not None: + metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) + + def _gather_and_numpify(self, tensors, name): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_tpu_available(): + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.local_rank != -1: + tensors = distributed_concat(tensors) + + return nested_numpify(tensors) + + def _add_sm_patterns_to_gitignore(self) -> None: + """Add SageMaker Checkpointing patterns to .gitignore file.""" + # Make sure we only do this on the main process + if not self.is_world_process_zero(): + return + + patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] + + # Get current .gitignore content + if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): + with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: + current_content = f.read() + else: + current_content = "" + + # Add the patterns to .gitignore + content = current_content + for pattern in patterns: + if pattern not in content: + if content.endswith("\n"): + content += pattern + else: + content += f"\n{pattern}" + + # Write the .gitignore file if it has changed + if content != current_content: + with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: + logger.debug(f"Writing .gitignore file. Content: {content}") + f.write(content) + + self.repo.git_add(".gitignore") + + # avoid race condition with git status + time.sleep(0.5) + + if not self.repo.is_repo_clean(): + self.repo.git_commit("Add *.sagemaker patterns to .gitignore.") + self.repo.git_push() diff --git a/ptuning/trainer_seq2seq.py b/ptuning/trainer_seq2seq.py new file mode 100644 index 0000000..19d5cf1 --- /dev/null +++ b/ptuning/trainer_seq2seq.py @@ -0,0 +1,247 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.data import Dataset + +from transformers.deepspeed import is_deepspeed_zero3_enabled +from trainer import Trainer +from transformers.trainer_utils import PredictionOutput +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Seq2SeqTrainer(Trainer): + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + **gen_kwargs + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + + gen_kwargs = gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.args.generation_max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + + return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def predict( + self, + test_dataset: Dataset, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "test", + **gen_kwargs + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + + + + If your predictions or labels have different sequence lengths (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + + + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + + gen_kwargs = gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.args.generation_max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + + + return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + + Return: + Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and + labels (each being optional). + """ + + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + + has_labels = "labels" in inputs + inputs = self._prepare_inputs(inputs) + + # XXX: adapt synced_gpus for fairscale as well + gen_kwargs = self._gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.model.config.max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams + ) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs["synced_gpus"] = ( + gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus + ) + + if "attention_mask" in inputs: + gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) + if "position_ids" in inputs: + gen_kwargs["position_ids"] = inputs.get("position_ids", None) + if "global_attention_mask" in inputs: + gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) + + # prepare generation inputs + # some encoder-decoder models can have varying encoder's and thus + # varying model input names + if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: + generation_inputs = inputs[self.model.encoder.main_input_name] + else: + generation_inputs = inputs[self.model.main_input_name] + + gen_kwargs["input_ids"] = generation_inputs + generated_tokens = self.model.generate(**gen_kwargs) + generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] + + # in case the batch is shorter than max length, the output should be padded + if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) + elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( + gen_kwargs["max_new_tokens"] + 1 + ): + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) + + loss = None + + if self.args.prediction_loss_only: + return (loss, None, None) + + if has_labels: + labels = inputs["labels"] + if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: + labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) + elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( + gen_kwargs["max_new_tokens"] + 1 + ): + labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) + else: + labels = None + + return (loss, generated_tokens, labels) + + def _pad_tensors_to_max_len(self, tensor, max_length): + if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = ( + self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + ) + else: + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") + + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, : tensor.shape[-1]] = tensor + return padded_tensor diff --git a/requirements.txt b/requirements.txt index f311c85..fb8d79f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ -protobuf>=3.19.5,<3.20.1 -transformers==4.26.1 -icetk +protobuf +transformers==4.27.1 cpm_kernels torch>=1.10 gradio +mdtex2html +sentencepiece accelerate \ No newline at end of file diff --git a/resources/WECHAT.md b/resources/WECHAT.md new file mode 100644 index 0000000..c9ee867 --- /dev/null +++ b/resources/WECHAT.md @@ -0,0 +1,7 @@ +
+ + +

扫码关注公众号,加入「ChatGLM交流群」

+

Scan the QR code to follow the official account and join the "ChatGLM Discussion Group"

+
+ diff --git a/resources/wechat.jpg b/resources/wechat.jpg new file mode 100644 index 0000000..8ea08f2 Binary files /dev/null and b/resources/wechat.jpg differ diff --git a/web_demo.py b/web_demo.py index 28790ec..4d9171b 100644 --- a/web_demo.py +++ b/web_demo.py @@ -1,44 +1,100 @@ import gradio as gr +import mdtex2html from utils import load_model_and_tokenizer model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) -MAX_TURNS = 20 -MAX_BOXES = MAX_TURNS * 2 +"""Override Chatbot.postprocess""" -def predict(input, max_length, top_p, temperature, history=None): - if history is None: - history = [] +def postprocess(self, y): + if y is None: + return [] + for i, (message, response) in enumerate(y): + y[i] = ( + None if message is None else mdtex2html.convert((message)), + None if response is None else mdtex2html.convert(response), + ) + return y + + +gr.Chatbot.postprocess = postprocess + + +def parse_text(text): + """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" + lines = text.split("\n") + lines = [line for line in lines if line != ""] + count = 0 + for i, line in enumerate(lines): + if "```" in line: + count += 1 + items = line.split('`') + if count % 2 == 1: + lines[i] = f'
'
+            else:
+                lines[i] = f'
' + else: + if i > 0: + if count % 2 == 1: + line = line.replace("`", "\`") + line = line.replace("<", "<") + line = line.replace(">", ">") + line = line.replace(" ", " ") + line = line.replace("*", "*") + line = line.replace("_", "_") + line = line.replace("-", "-") + line = line.replace(".", ".") + line = line.replace("!", "!") + line = line.replace("(", "(") + line = line.replace(")", ")") + line = line.replace("$", "$") + lines[i] = "
"+line + text = "".join(lines) + return text + + +def predict(input, chatbot, max_length, top_p, temperature, history): + chatbot.append((parse_text(input), "")) for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, temperature=temperature): - updates = [] - for query, response in history: - updates.append(gr.update(visible=True, value="用户:" + query)) - updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response)) - if len(updates) < MAX_BOXES: - updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates)) - yield [history] + updates + chatbot[-1] = (parse_text(input), parse_text(response)) + + yield chatbot, history + + +def reset_user_input(): + return gr.update(value='') + + +def reset_state(): + return [], [] with gr.Blocks() as demo: - state = gr.State([]) - text_boxes = [] - for i in range(MAX_BOXES): - if i % 2 == 0: - text_boxes.append(gr.Markdown(visible=False, label="提问:")) - else: - text_boxes.append(gr.Markdown(visible=False, label="回复:")) + gr.HTML("""

ChatGLM

""") + chatbot = gr.Chatbot() with gr.Row(): with gr.Column(scale=4): - txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style( - container=False) + with gr.Column(scale=12): + user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( + container=False) + with gr.Column(min_width=32, scale=1): + submitBtn = gr.Button("Submit", variant="primary") with gr.Column(scale=1): + emptyBtn = gr.Button("Clear History") max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) - button = gr.Button("Generate") - button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) + + history = gr.State([]) + + submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], + show_progress=True) + submitBtn.click(reset_user_input, [], [user_input]) + + emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) + demo.queue().launch(share=False, inbrowser=True) diff --git a/web_demo2.py b/web_demo2.py index 82fb033..ebc9b8a 100644 --- a/web_demo2.py +++ b/web_demo2.py @@ -19,7 +19,7 @@ MAX_TURNS = 20 MAX_BOXES = MAX_TURNS * 2 -def predict(input, history=None): +def predict(input, max_length, top_p, temperature, history=None): tokenizer, model = get_model() if history is None: history = [] @@ -33,7 +33,8 @@ def predict(input, history=None): message(input, avatar_style="big-smile", key=str(len(history)) + "_user") st.write("AI正在回复:") with st.empty(): - for response, history in model.stream_chat(tokenizer, input, history): + for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, + temperature=temperature): query, response = history[-1] st.write(response) @@ -47,6 +48,15 @@ prompt_text = st.text_area(label="用户命令输入", height = 100, placeholder="请在这儿输入您的命令") +max_length = st.sidebar.slider( + 'max_length', 0, 4096, 2048, step=1 +) +top_p = st.sidebar.slider( + 'top_p', 0.0, 1.0, 0.6, step=0.01 +) +temperature = st.sidebar.slider( + 'temperature', 0.0, 1.0, 0.95, step=0.01 +) if 'state' not in st.session_state: st.session_state['state'] = [] @@ -54,4 +64,4 @@ if 'state' not in st.session_state: if st.button("发送", key="predict"): with st.spinner("AI正在思考,请稍等........"): # text generation - st.session_state["state"] = predict(prompt_text, st.session_state["state"]) + st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"]) \ No newline at end of file diff --git a/web_demo_old.py b/web_demo_old.py new file mode 100644 index 0000000..88a6dc8 --- /dev/null +++ b/web_demo_old.py @@ -0,0 +1,45 @@ +from transformers import AutoModel, AutoTokenizer +import gradio as gr + +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +model = model.eval() + +MAX_TURNS = 20 +MAX_BOXES = MAX_TURNS * 2 + + +def predict(input, max_length, top_p, temperature, history=None): + if history is None: + history = [] + for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, + temperature=temperature): + updates = [] + for query, response in history: + updates.append(gr.update(visible=True, value="用户:" + query)) + updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response)) + if len(updates) < MAX_BOXES: + updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates)) + yield [history] + updates + + +with gr.Blocks() as demo: + state = gr.State([]) + text_boxes = [] + for i in range(MAX_BOXES): + if i % 2 == 0: + text_boxes.append(gr.Markdown(visible=False, label="提问:")) + else: + text_boxes.append(gr.Markdown(visible=False, label="回复:")) + + with gr.Row(): + with gr.Column(scale=4): + txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style( + container=False) + with gr.Column(scale=1): + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) + button = gr.Button("Generate") + button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) +demo.queue().launch(share=False, inbrowser=True)