initial commit
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,171 @@
|
|||
# InternLM
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="./doc/imgs/logo.svg" width="200"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">书生·浦语 官网</font></b>
|
||||
<sup>
|
||||
<a href="https://internlm.intern-ai.org.cn/">
|
||||
<i><font size="4">HOT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
<div> </div>
|
||||
</div>
|
||||
|
||||
[![license](./doc/imgs/license.svg)](https://github.com/open-mmlab/mmdetection/blob/main/LICENSE)
|
||||
|
||||
[📘使用文档](./doc/usage.md) |
|
||||
[🛠️安装教程](./doc/install.md) |
|
||||
[📊训练性能](./doc/train_performance.md) |
|
||||
[👀模型库](#model-zoo) |
|
||||
[🆕Update News](./CHANGE_LOG.md) |
|
||||
[🤔Reporting Issues](https://github.com/InternLM/InternLM/issues/new)
|
||||
|
||||
|
||||
[English](./README.md) |
|
||||
[简体中文](./README-zh-Hans.md)
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
## 简介
|
||||
InternLM ,即书生·浦语大模型,包含面向实用场景的70亿参数基础模型与对话模型 (InternLM-7B)。模型具有以下特点:
|
||||
- 使用上万亿高质量预料,建立模型超强知识体系;
|
||||
- 支持8k语境窗口长度,实现更长输入与更强推理体验;
|
||||
- 通用工具调用能力,支持用户灵活自助搭建流程;
|
||||
|
||||
提供了支持模型预训练的轻量级训练框架,无需安装大量依赖包,一套代码支持千卡预训练和单卡人类偏好对齐训练,同时实现了极致的性能优化,实现千卡训练下近90%加速效率。
|
||||
|
||||
## InternLM-7B
|
||||
|
||||
### 性能评测
|
||||
|
||||
我们使用开源评测工具 [OpenCompass](https://github.com/internLM/OpenCompass/) 从学科综合能力、语言能力、知识能力、推理能力、理解能力五大能力维度对InternLM开展全面评测,部分评测结果如下表所示,欢迎访问[ OpenCompass 榜单 ](https://opencompass.org.cn/rank)获取更多的评测结果。
|
||||
|
||||
| 数据集\模型 | **InternLM-Chat-7B** | **InternLM-7B** | LLaMA-7B | Baichuan-7B | ChatGLM2-6B | Alpaca-7B | Vicuna-7B |
|
||||
| -------------------- | --------------------- | ---------------- | --------- | --------- | ------------ | --------- | ---------- |
|
||||
| C-Eval(Val) | 53.2 | 53.4 | 24.2 | 42.7 | 50.9 | 28.9 | 31.2 |
|
||||
| MMLU | 50.8 | 51.0 | 35.2* | 41.5 | 46.0 | 39.7 | 47.3 |
|
||||
| AGIEval | 42.5 | 37.6 | 20.8 | 24.6 | 39.0 | 24.1 | 26.4 |
|
||||
| CommonSenseQA | 75.2 | 59.5 | 65.0 | 58.8 | 60.0 | 68.7 | 66.7 |
|
||||
| BUSTM | 74.3 | 50.6 | 48.5 | 51.3 | 55.0 | 48.8 | 62.5 |
|
||||
| CLUEWSC | 78.6 | 59.1 | 50.3 | 52.8 | 59.8 | 50.3 | 52.2 |
|
||||
| CommonSenseQA | 75.2 | 59.5 | 60.0 | 58.8 | 60.0 | 68.7 | 66.7 |
|
||||
| MATH | 6.4 | 7.1 | 2.8 | 3.0 | 6.6 | 2.2 | 2.8 |
|
||||
| GSM8K | 34.5 | 31.2 | 10.1 | 9.7 | 29.2 | 6.0 | 15.3 |
|
||||
| HumanEval | 14.0 | 10.4 | 14.0 | 9.2 | 9.2 | 9.2 | 11.0 |
|
||||
| RACE(High) | 76.3 | 57.4 | 46.9* | 28.1 | 66.3 | 40.7 | 54.0 |
|
||||
|
||||
- 以上评测结果基于 [OpenCompass 20230706](https://github.com/internLM/OpenCompass/) 获得(部分数据标注`*`代表数据来自原始论文),具体测试细节可参见 [OpenCompass](https://github.com/internLM/OpenCompass/) 中提供的配置文件。
|
||||
- 评测数据会因 [OpenCompass](https://github.com/internLM/OpenCompass/) 的版本迭代而存在数值差异,请以 [OpenCompass](https://github.com/internLM/OpenCompass/) 最新版的评测结果为主。
|
||||
|
||||
### Model Zoo
|
||||
当前通过 InternLM 训练的 InternLM 7B 和 InternLM 7B Chat 已经开源,我们提供两种格式的模型权重以供使用。除了使用 Transformers 格式加载模型之外,还可以通过 InternLM 加载以下格式的权重直接进行继续预训练或人类偏好对齐训练
|
||||
|
||||
| 模型 | InternLM 格式权重下载地址 | Transformers 格式权重下载地址 |
|
||||
| -------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------ |
|
||||
| **InternLM 7B** | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/InternLM-7b) | [🤗internlm/intern-7b](https://huggingface.co/internlm/internlm-7b) |
|
||||
| **InternLM Chat 7B** | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/InternLM-chat-7b) | [🤗internlm/intern-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
|
||||
|
||||
**局限性:** 尽管在训练过程中我们非常注重模型的安全性,尽力促使模型输出符合伦理和法律要求的文本,但受限于模型大小以及概率生成范式,模型可能会产生各种不符合预期的输出,例如回复内容包含偏见、歧视等有害内容,请勿传播这些内容。由于传播不良信息导致的任何后果,本项目不承担责任。
|
||||
|
||||
### 通过 Transformers 加载
|
||||
通过以下的代码加载 InternLM 7B Chat 模型
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
|
||||
>>> model = AutoModel.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True, device='cuda')
|
||||
>>> model = model.eval()
|
||||
>>> response, history = model.chat(tokenizer, "你好", history=[])
|
||||
>>> print(response)
|
||||
你好!有什么我可以帮助你的吗?
|
||||
>>> response, history = model.chat(tokenizer, "请提供三个管理时间的建议。", history=history)
|
||||
>>> print(response)
|
||||
当然可以!以下是三个管理时间的建议:
|
||||
1. 制定计划:制定一个详细的计划,包括每天要完成的任务和活动。这将有助于您更好地组织时间,并确保您能够按时完成任务。
|
||||
2. 优先级:将任务按照优先级排序,先完成最重要的任务。这将确保您能够在最短的时间内完成最重要的任务,从而节省时间。
|
||||
3. 集中注意力:避免分心,集中注意力完成任务。关闭社交媒体和电子邮件通知,专注于任务,这将帮助您更快地完成任务,并减少错误的可能性。
|
||||
```
|
||||
|
||||
### 通过前端网页对话
|
||||
可以通过以下代码启动一个前端的界面来与 InternLM Chat 7B 模型进行交互
|
||||
```bash
|
||||
pip install streamlit==1.24.0
|
||||
pip install transformers==4.30.2
|
||||
streamlit run web_demo.py
|
||||
```
|
||||
效果如下
|
||||
|
||||
![效果](https://github.com/InternLM/InternLM/assets/9102141/08ec4541-9126-4d5f-b5c0-53947bc1d8bb)
|
||||
|
||||
### 基于InternLM高性能部署
|
||||
|
||||
我们使用 [LMDeploy](https://github.com/InternLM/LMDeploy) 完成 InternLM 的一键部署。
|
||||
|
||||
1. 首先安装 LMDeploy:
|
||||
|
||||
```
|
||||
python3 -m pip install lmdeploy
|
||||
```
|
||||
|
||||
2. 快速的部署命令如下:
|
||||
|
||||
```
|
||||
python3 -m lmdeploy.serve.turbomind.deploy InternLM-7B /path/to/internlm-7b/model hf
|
||||
```
|
||||
|
||||
3. 在导出模型后,你可以直接通过如下命令启动服务一个服务并和部署后的模型对话
|
||||
|
||||
```
|
||||
python3 -m lmdeploy.serve.client {server_ip_addresss}:33337
|
||||
```
|
||||
|
||||
[LMDeploy](https://github.com/InternLM/LMDeploy) 支持了 InternLM 部署的完整流程,请参考 [部署教程](https://github.com/InternLM/LMDeploy) 了解 InternLM 的更多部署细节。
|
||||
|
||||
## 微调&训练
|
||||
|
||||
### 预训练与微调使用教程
|
||||
请参考[使用教程](./doc/usage.md)开始InternLM的安装、数据处理、预训练与微调。
|
||||
|
||||
### 转换为 Transformers 格式使用
|
||||
通过 InternLM 进行训练的模型可以很轻松地转换为 HuggingFace Transformers 格式,方便与社区各种开源项目无缝对接。借助 `tools/convert2hf.py` 可以将训练保存的权重一键转换为 transformers 格式
|
||||
```bash
|
||||
python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
|
||||
```
|
||||
转换之后可以通过以下的代码加载为 transformers
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
>>> model = AutoModel.from_pretrained("hf_ckpt/", trust_remote_code=True, device='cuda')
|
||||
```
|
||||
|
||||
|
||||
## 训练系统
|
||||
|
||||
### 系统结构
|
||||
请参考[系统结构文档](./doc/structure.md)进一步了解。
|
||||
|
||||
### 训练性能
|
||||
|
||||
InternLM 深度整合了 Flash-Attention, Apex 等高性能模型算子,提高了训练效率。通过构建 Hybrid Zero 技术,实现计算和通信的高效重叠,大幅降低了训练过程中的跨节点通信流量。InternLM 支持 7B 模型从 8 卡扩展到 1024 卡,千卡规模下加速效率可高达 90%,训练吞吐超过 180TFLOPS,平均单卡每秒处理的 token 数量超过3600。下表为 InternLM 在不同配置下的扩展性测试数据:
|
||||
|
||||
| GPU数量 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 |
|
||||
| ---------------- | ---- | ---- | ---- | ---- | ----- | ----- | ----- | ------ |
|
||||
| TKS | 4078 | 3939 | 3919 | 3944 | 3928 | 3920 | 3835 | 3625 |
|
||||
| TFLOPS | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 |
|
||||
|
||||
TKS 代表平均每GPU每秒可以处理的 Token 数量。更多的性能测试数据可参考[训练性能文档](./doc/train_performance.md)进一步了解。
|
||||
|
||||
|
||||
## 贡献
|
||||
|
||||
我们感谢所有的贡献者为改进和提升 InternLM 所作出的努力。非常欢迎社区用户能参与进项目中来。请参考贡献指南来了解参与项目贡献的相关指引。
|
||||
|
||||
## 致谢
|
||||
|
||||
InternLM 代码库是一款由上海人工智能实验室和来自不同高校、企业的研发人员共同参与贡献的开源项目。我们感谢所有为项目提供新功能支持的贡献者,以及提供宝贵反馈的用户。 我们希望这个工具箱和基准测试可以为社区提供灵活高效的代码工具,供用户微调 InternLM 并开发自己的新模型,从而不断为开源社区提供贡献。特别鸣谢flash-attention (https://github.com/HazyResearch/flash-attention) 与 ColossalAI (https://github.com/hpcaitech/ColossalAI) 两项开源项目。
|
||||
|
||||
## 开源许可证
|
||||
|
||||
本仓库的代码依照 Apache-2.0 协议开源。InternLM 权重对学术研究完全开放,在获得官方的书面许可后,亦允许商业使用。申请商用许可与合作请联系 internlm@pjlab.org.cn。
|
|
@ -0,0 +1,177 @@
|
|||
# InternLM
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="./doc/imgs/logo.svg" width="200"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">InternLM</font></b>
|
||||
<sup>
|
||||
<a href="https://internlm.intern-ai.org.cn/">
|
||||
<i><font size="4">HOT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
<div> </div>
|
||||
</div>
|
||||
|
||||
[![license](./doc/imgs/license.svg)](./LICENSE)
|
||||
[![evaluation](./doc/imgs/compass_support.svg)](https://github.com/internLM/OpenCompass/)
|
||||
|
||||
[📘Usage](./doc/en/usage.md) |
|
||||
[🛠️Installation](./doc/en/install.md) |
|
||||
[📊Train Performance](./doc/en/train_performance.md) |
|
||||
[👀Model](#model-zoo) |
|
||||
[🆕Update News](./CHANGE_LOG.md) |
|
||||
[🤔Reporting Issues](https://github.com/InternLM/InternLM/issues/new)
|
||||
|
||||
|
||||
[English](./README.md) |
|
||||
[简体中文](./README-zh-Hans.md)
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
## Introduction
|
||||
|
||||
InternLM has open-sourced a 7 billion parameter base model and a chat model tailored for practical scenarios. The model has the following characteristics:
|
||||
- It leverages trillions of high-quality tokens for training to establish a powerful knowledge base.
|
||||
- It supports an 8k context window length, enabling longer input sequences and stronger reasoning capabilities.
|
||||
- It provides a versatile toolset for users to flexibly build their own workflows.
|
||||
|
||||
Additionally, a lightweight training framework is offered to support model pre-training without the need for extensive dependencies. With a single codebase, it supports pre-training on large-scale clusters with thousands of GPUs, and fine-tuning on a single GPU while achieving remarkable performance optimizations. InternLM achieves nearly 90% acceleration efficiency during training on 1024 GPUs.
|
||||
|
||||
## InternLM-7B
|
||||
|
||||
### Performance Evaluation
|
||||
|
||||
We conducted a comprehensive evaluation of InternLM using the open-source evaluation tool [OpenCompass](https://github.com/internLM/OpenCompass/). The evaluation covered five dimensions of capabilities: disciplinary competence, language competence, knowledge competence, inference competence, and comprehension competence. Here are some of the evaluation results, and you can visit the [OpenCompass leaderboard](https://opencompass.org.cn/rank) for more evaluation results.
|
||||
|
||||
| Datasets\Models | **InternLM-Chat-7B** | **InternLM-7B** | LLaMA-7B | Baichuan-7B | ChatGLM2-6B | Alpaca-7B | Vicuna-7B |
|
||||
| -------------------- | --------------------- | ---------------- | --------- | --------- | ------------ | --------- | ---------- |
|
||||
| C-Eval(Val) | 53.2 | 53.4 | 24.2 | 42.7 | 50.9 | 28.9 | 31.2 |
|
||||
| MMLU | 50.8 | 51.0 | 35.2* | 41.5 | 46.0 | 39.7 | 47.3 |
|
||||
| AGIEval | 42.5 | 37.6 | 20.8 | 24.6 | 39.0 | 24.1 | 26.4 |
|
||||
| CommonSenseQA | 75.2 | 59.5 | 65.0 | 58.8 | 60.0 | 68.7 | 66.7 |
|
||||
| BUSTM | 74.3 | 50.6 | 48.5 | 51.3 | 55.0 | 48.8 | 62.5 |
|
||||
| CLUEWSC | 78.6 | 59.1 | 50.3 | 52.8 | 59.8 | 50.3 | 52.2 |
|
||||
| CommonSenseQA | 75.2 | 59.5 | 60.0 | 58.8 | 60.0 | 68.7 | 66.7 |
|
||||
| MATH | 6.4 | 7.1 | 2.8 | 3.0 | 6.6 | 2.2 | 2.8 |
|
||||
| GSM8K | 34.5 | 31.2 | 10.1 | 9.7 | 29.2 | 6.0 | 15.3 |
|
||||
| HumanEval | 14.0 | 10.4 | 14.0 | 9.2 | 9.2 | 9.2 | 11.0 |
|
||||
| RACE(High) | 76.3 | 57.4 | 46.9* | 28.1 | 66.3 | 40.7 | 54.0 |
|
||||
|
||||
- The evaluation results were obtained from[OpenCompass 20230706](https://github.com/internLM/OpenCompass/) (some data marked with *, which menas come from the original papers), and evaluation configuration can be found in the configuration files provided by [OpenCompass](https://github.com/internLM/OpenCompass/).
|
||||
- The evaluation data may have numerical differences due to the version iteration of [OpenCompass](https://github.com/internLM/OpenCompass/), so please refer to the latest evaluation results of [OpenCompass](https://github.com/internLM/OpenCompass/).
|
||||
|
||||
### Model Zoo
|
||||
InternLM 7B and InternLM 7B Chat, trained using InternLM, have been open-sourced. We provide two formats of model weights for use. In addition to loading the models using the Transformers format, you can also load the weights directly using InternLM for further pre-training or human preference alignment training.
|
||||
|
||||
| Model | InternLM Format Weight Download Link | Transformers Format Weight Download Link |
|
||||
| -------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------ |
|
||||
| **InternLM 7B** | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/InternLM-7b) | [🤗internlm/intern-7b](https://huggingface.co/internlm/internlm-7b) |
|
||||
| **InternLM Chat 7B** | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/InternLM-chat-7b) | [🤗internlm/intern-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) |
|
||||
|
||||
|
||||
**Limitations:** Although we have made efforts to ensure the safety of the model during the training process and to encourage the model to generate text that complies with ethical and legal requirements, the model may still produce unexpected outputs due to its size and probabilistic generation paradigm. For example, the generated responses may contain biases, discrimination, or other harmful content. Please do not propagate such content. We are not responsible for any consequences resulting from the dissemination of harmful information.
|
||||
|
||||
### Import from Transformers
|
||||
To load the InternLM 7B Chat model using Transformers, use the following code:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
|
||||
>>> model = AutoModel.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True, device='cuda')
|
||||
>>> model = model.eval()
|
||||
>>> response, history = model.chat(tokenizer, "hello", history=[])
|
||||
>>> print(response)
|
||||
Hello! How can I help you today?
|
||||
>>> response, history = model.chat(tokenizer, "please provide three suggestions about time management", history=history)
|
||||
>>> print(response)
|
||||
Sure, here are three tips for effective time management:
|
||||
|
||||
1. Prioritize tasks based on importance and urgency: Make a list of all your tasks and categorize them into "important and urgent," "important but not urgent," and "not important but urgent." Focus on completing the tasks in the first category before moving on to the others.
|
||||
2. Use a calendar or planner: Write down deadlines and appointments in a calendar or planner so you don't forget them. This will also help you schedule your time more effectively and avoid overbooking yourself.
|
||||
3. Minimize distractions: Try to eliminate any potential distractions when working on important tasks. Turn off notifications on your phone, close unnecessary tabs on your computer, and find a quiet place to work if possible.
|
||||
|
||||
Remember, good time management skills take practice and patience. Start with small steps and gradually incorporate these habits into your daily routine.
|
||||
```
|
||||
|
||||
### Dialogue
|
||||
You can interact with the InternLM Chat 7B model through a frontend interface by running the following code:
|
||||
```bash
|
||||
pip install streamlit==1.24.0
|
||||
pip install transformers==4.30.2
|
||||
streamlit run web_demo.py
|
||||
```
|
||||
The effect is as follows
|
||||
|
||||
![demo](https://github.com/InternLM/InternLM/assets/9102141/08ec4541-9126-4d5f-b5c0-53947bc1d8bb)
|
||||
|
||||
### Deployment
|
||||
|
||||
We use [LMDeploy](https://github.com/InternLM/LMDeploy) to complete the one-click deployment of InternLM.
|
||||
|
||||
1. First, install LMDeploy:
|
||||
|
||||
```
|
||||
python3 -m pip install lmdeploy
|
||||
```
|
||||
|
||||
2. Use the following command for quick deployment:
|
||||
|
||||
```
|
||||
python3 -m lmdeploy.serve.turbomind.deploy InternLM-7B /path/to/internlm-7b/model hf
|
||||
```
|
||||
|
||||
3. After exporting the model, you can start a server and have a conversation with the deployed model using the following command:
|
||||
|
||||
```
|
||||
python3 -m lmdeploy.serve.client {server_ip_addresss}:33337
|
||||
```
|
||||
|
||||
[LMDeploy](https://github.com/InternLM/LMDeploy) provides a complete workflow for deploying InternLM. Please refer to the [deployment tutorial](https://github.com/InternLM/LMDeploy) for more details on deploying InternLM.
|
||||
|
||||
## Fine-tuning & Training
|
||||
|
||||
### Pre-training and Fine-tuning Tutorial
|
||||
Please refer to [Usage Tutorial](./doc/en/usage.md) to start InternLM installation, data processing, pre-training and fine-tuning.
|
||||
|
||||
### Convert to Transformers Format
|
||||
The model trained by InternLM can be easily converted to HuggingFace Transformers format, which is convenient for seamless docking with various open source projects in the community. With the help of `tools/convert2hf.py`, the weights saved during training can be converted into transformers format with one command
|
||||
```bash
|
||||
python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
|
||||
```
|
||||
After conversion, it can be loaded as transformers by the following code
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
>>> model = AutoModel.from_pretrained("hf_ckpt/", trust_remote_code=True, device='cuda')
|
||||
```
|
||||
|
||||
|
||||
## Training System
|
||||
|
||||
### System Architecture
|
||||
Please refer to the [System Architecture document](./doc/en/structure.md) for further details.
|
||||
|
||||
### Training Performance
|
||||
|
||||
InternLM deeply integrates Flash-Attention, Apex and other high-performance model operators to improve training efficiency. By building the Hybrid Zero technique, it achieves efficient overlap of computation and communication, significantly reducing cross-node communication traffic during training. InternLM supports expanding the 7B model from 8 GPUs to 1024 GPUs, with an acceleration efficiency of up to 90% at the thousand-GPU scale, a training throughput of over 180 TFLOPS, and an average of over 3600 tokens per GPU per second. The following table shows InternLM's scalability test data at different configurations:
|
||||
|
||||
| Number of GPUs | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 |
|
||||
| -------------- | ------ | ------- | ------- | ------- | -------- | -------- | -------- | --------- |
|
||||
| TGS | 4078 | 3939 | 3919 | 3944 | 3928 | 3920 | 3835 | 3625 |
|
||||
| TFLOPS | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 |
|
||||
|
||||
TGS represents the average number of tokens processed per GPU per second. For more performance test data, please refer to the [Training Performance document](./doc/en/train_performance.md) for further details.
|
||||
|
||||
|
||||
## Contribution
|
||||
|
||||
We appreciate all the contributors for their efforts to improve and enhance InternLM. Community users are highly encouraged to participate in the project. Please refer to the contribution guidelines for instructions on how to contribute to the project.
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
InternLM codebase is an open-source project contributed by Shanghai AI Laboratory and researchers from different universities and companies. We would like to thank all the contributors for their support in adding new features to the project and the users for providing valuable feedback. We hope that this toolkit and benchmark can provide the community with flexible and efficient code tools for fine-tuning InternLM and developing their own models, thus continuously contributing to the open-source community. Special thanks to the two open-source projects, flash-attention (https://github.com/HazyResearch/flash-attention) and ColossalAI (https://github.com/hpcaitech/ColossalAI).
|
||||
|
||||
## Open Source License
|
||||
|
||||
The code in this repository is open-source under the Apache-2.0 license. The InternLM weights are fully open for academic research and also allow commercial use with written permission from the official team. For inquiries about commercial licenses and collaborations, please contact internlm@pjlab.org.cn.
|
|
@ -0,0 +1,129 @@
|
|||
JOB_NAME = "7b_train"
|
||||
|
||||
SEQ_LEN = 2048
|
||||
HIDDEN_SIZE = 4096
|
||||
NUM_ATTENTION_HEAD = 32
|
||||
MLP_RATIO = 8 / 3
|
||||
NUM_LAYER = 32
|
||||
VOCAB_SIZE = 103168
|
||||
|
||||
# Ckpt folder format:
|
||||
# fs: 'local:/mnt/nfs/XXX'
|
||||
# oss: 'boto3:s3://model_weights/XXX'
|
||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
||||
SAVE_CKPT_FOLDER = "local:llm_ckpts"
|
||||
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
|
||||
ckpt = dict(
|
||||
# Path to save training ckpt.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER,
|
||||
# Path to continue training ckpt (load model weights and scheduler/context states).
|
||||
# load_ckpt_folder=LOAD_CKPT_FOLDER,
|
||||
# Path to initialize with given model weights.
|
||||
# load_model_only_folder=MODEL_ONLY_FOLDER,
|
||||
checkpoint_every=50,
|
||||
# Wheter to load optimizer states when continuing training.
|
||||
load_optimizer=True,
|
||||
)
|
||||
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN,
|
||||
# micro_num means the number of micro_batch contained in one gradient update
|
||||
micro_num=4,
|
||||
# packed_length = micro_bsz * SEQ_LEN
|
||||
micro_bsz=2,
|
||||
pack_sample_into_one=False,
|
||||
total_steps=50000,
|
||||
skip_batches="",
|
||||
rampup_batch_size="",
|
||||
# Datasets with less than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
)
|
||||
|
||||
grad_scaler = dict(
|
||||
fp16=dict(
|
||||
# the initial loss scale, defaults to 2**16
|
||||
initial_scale=2**16,
|
||||
# the minimum loss scale, defaults to None
|
||||
min_scale=1,
|
||||
# the number of steps to increase loss scale when no overflow occurs
|
||||
growth_interval=1000,
|
||||
),
|
||||
# the multiplication factor for increasing loss scale, defaults to 2
|
||||
growth_factor=2,
|
||||
# the multiplication factor for decreasing loss scale, defaults to 0.5
|
||||
backoff_factor=0.5,
|
||||
# the maximum loss scale, defaults to None
|
||||
max_scale=2**24,
|
||||
# the number of overflows before decreasing loss scale, defaults to 2
|
||||
hysteresis=2,
|
||||
)
|
||||
|
||||
hybrid_zero_optimizer = dict(
|
||||
# Enable low_level_optimzer overlap_communication
|
||||
zero_overlap_communication=True,
|
||||
# bucket size for nccl communication params
|
||||
reduce_bucket_size=512 * 1024 * 1024,
|
||||
# grad clipping
|
||||
clip_grad_norm=1.0,
|
||||
)
|
||||
|
||||
loss = dict(
|
||||
label_smoothing=0,
|
||||
)
|
||||
|
||||
adam = dict(
|
||||
lr=1e-4,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.95,
|
||||
adam_beta2_c=0,
|
||||
adam_eps=1e-8,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
lr_scheduler = dict(
|
||||
total_steps=data["total_steps"],
|
||||
init_steps=0, # optimizer_warmup_step
|
||||
warmup_ratio=0.01,
|
||||
eta_min=1e-5,
|
||||
last_epoch=-1,
|
||||
)
|
||||
|
||||
beta2_scheduler = dict(
|
||||
init_beta2=adam["adam_beta2"],
|
||||
c=adam["adam_beta2_c"],
|
||||
cur_iter=-1,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
checkpoint=False,
|
||||
num_attention_heads=NUM_ATTENTION_HEAD,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.bfloat16",
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
|
||||
so parameters will be divided within the range of dp.
|
||||
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
||||
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
|
||||
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
||||
pipeline parallel: pipeline parallel size, only 1 is accepted currently.
|
||||
tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
cudnn_benchmark = False
|
|
@ -0,0 +1,58 @@
|
|||
## InternLM Installation
|
||||
|
||||
### Environment Preparation
|
||||
The required packages and corresponding version are shown as follows:
|
||||
- Python == 3.10
|
||||
- GCC == 10.2.0
|
||||
- MPFR == 4.1.0
|
||||
- CUDA == 11.7
|
||||
- Pytorch == 1.13.1+cu117
|
||||
- Transformers >= 4.25.1
|
||||
- Flash-Attention == 23.05
|
||||
- GPU with Ampere or Hopper architecture (such as H100, A100)
|
||||
- Linux OS
|
||||
|
||||
After installing the above dependencies, some system environment variables need to be updated:
|
||||
```bash
|
||||
export CUDA_PATH={path_of_cuda_11.7}
|
||||
export GCC_HOME={path_of_gcc_10.2.0}
|
||||
export MPFR_HOME={path_of_mpfr_4.1.0}
|
||||
export LD_LIBRARY_PATH=${GCC_HOME}/lib64:${MPFR_HOME}/lib:${CUDA_PATH}/lib64:$LD_LIBRARY_PATH
|
||||
export PATH=${GCC_HOME}/bin:${CUDA_PATH}/bin:$PATH
|
||||
export CC=${GCC_HOME}/bin/gcc
|
||||
export CXX=${GCC_HOME}/bin/c++
|
||||
```
|
||||
|
||||
### Environment Installation
|
||||
Clone the project `internlm` and its dependent submodules from the github repository, as follows:
|
||||
```bash
|
||||
git clone git@github.com:InternLM/InternLM.git --recurse-submodules
|
||||
```
|
||||
|
||||
It is recommended to build a Python-3.10 virtual environment using conda and install the required dependencies based on the `requirements/` files:
|
||||
```bash
|
||||
conda create --name internlm-env python=3.10 -y
|
||||
conda activate internlm-env
|
||||
cd internlm
|
||||
pip install -r requirements/torch.txt
|
||||
pip install -r requirements/runtime.txt
|
||||
```
|
||||
|
||||
Install flash-attention (version v1.0.5):
|
||||
```bash
|
||||
cd ./third_party/flash-attention
|
||||
python setup.py install
|
||||
cd ./csrc
|
||||
cd fused_dense_lib && pip install -v .
|
||||
cd ../xentropy && pip install -v .
|
||||
cd ../rotary && pip install -v .
|
||||
cd ../layer_norm && pip install -v .
|
||||
cd ../../../../
|
||||
```
|
||||
|
||||
Install Apex (version 23.05):
|
||||
```bash
|
||||
cd ./third_party/apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
||||
cd ../../
|
||||
```
|
|
@ -0,0 +1,25 @@
|
|||
## InternLM System Structure
|
||||
The system code file structure is shown below:
|
||||
```bash
|
||||
├── configs # Configuration module, managing model and training-related parameters
|
||||
│ └── 7B_sft.py # 7B_sft.py is a sample configuration file for the system demo
|
||||
├── internlm # Main directory of the system code
|
||||
│ ├── apis # Interface module, containing some interface functions related to inference, etc.
|
||||
│ ├── core # Core module, managing parallel context and training scheduling engine for training and inference
|
||||
│ │ ├── context # Context module, mainly responsible for initializing parallel process groups and managing parallel context
|
||||
│ │ │ ├── parallel_context.py
|
||||
│ │ │ └── process_group_initializer.py
|
||||
│ │ ├── engine.py # Responsible for managing the training and evaluation process of the model
|
||||
│ │ ├── no_pipeline_scheduler.py # Scheduler for parallel training
|
||||
│ │ └── trainer.py # Responsible for managing the training engine and scheduler
|
||||
│ ├── data # Data module, responsible for managing dataset generation and processing
|
||||
│ ├── initialize # Initialization module, responsible for managing distributed environment startup and trainer initialization
|
||||
│ ├── model # Model module, responsible for managing model structure definition and implementation
|
||||
│ ├── solver # Responsible for managing the implementation of optimizer and lr_scheduler, etc.
|
||||
│ └── utils # Auxiliary module, responsible for managing logs, storage, model registration, etc.
|
||||
├── train.py # Main function entry file for model training
|
||||
├── requirements # List of dependent packages for system running
|
||||
├── third_party # Third-party modules on which the system depends, including apex and flash-attention, etc.
|
||||
├── tools # Some script tools for processing and converting raw datasets, model checkpoint conversion, etc.
|
||||
└── version.txt # System version number
|
||||
```
|
|
@ -0,0 +1,66 @@
|
|||
## Training Performance
|
||||
|
||||
|
||||
InternLM deeply integrates Flash-Attention, Apex, and other high-performance model operators to improve training efficiency. It achieves efficient overlap of computation and communication, significantly reducing cross-node communication traffic during training by building the Hybrid Zero technique. InternLM supports expanding the 7B model from 8 GPUs to 1024 GPUs, with an acceleration efficiency of up to 90% at the thousand-card scale, a training throughput of over 180 TFLOPS, and an average of over 3600 tokens per GPU per second. The following table shows InternLM's scalability test data at different configurations:
|
||||
|
||||
| GPU Number | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 |
|
||||
| ---------------- | ---- | ---- | ---- | ---- | ----- | ----- | ----- | ------ |
|
||||
| TGS (Tokens/GPU/Second) | 4078 | 3939 | 3919 | 3944 | 3928 | 3920 | 3835 | 3625 |
|
||||
| TFLOPS | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 |
|
||||
|
||||
|
||||
We tested the performance of training the 7B model in InternLM using various parallel configurations on a GPU cluster. In each test group, the number of tokens processed per GPU in a single iteration remained consistent. The hardware and parameter configurations used in the tests are shown in the table below:
|
||||
|
||||
| Hardware | Model |
|
||||
| ----------------------- | ----------------------------- |
|
||||
| GPU | nvidia_a100-sxm4-80gb |
|
||||
| Memory | 2TB |
|
||||
| Inter-machine bandwidth | 4 * 100Gb RoCE |
|
||||
| CPU | 128 core Intel(R) Xeon(R) CPU |
|
||||
|
||||
| Hyperparameters | tp=1 | tp=2 |
|
||||
| --------------- | ---- | ---- |
|
||||
| micro_num | 4 | 4 |
|
||||
| micro_bsz | 2 | 4 |
|
||||
| seq_len | 2048 | 2048 |
|
||||
|
||||
The configuration of `zero1` in InternLM determines the allocation range of optimizer states.
|
||||
- `zero1=-1` indicates that optimizer states are distributed across all data-parallel nodes (equivalent to Deepspeed Zero-1).
|
||||
- In the case of `zero1=8, tp=1`, optimizer states are distributed within 8 GPUs in a single node, and the optimizer states remain consistent across different nodes.
|
||||
|
||||
### Throughput Measurement
|
||||
|
||||
Throughput is defined as TGS, the average number of tokens processed per GPU per second. In this test, the training configuration had `pack_sample_into_one=False` and `checkpoint=False`. The test results are shown in the following table. When using `zero1=8, tp=1`, InternLM achieves an acceleration efficiency of `88%` for training the 7B model with a thousand cards.
|
||||
|
||||
| Parallel Configuration | 8 GPUs | 16 GPUs | 32 GPUs | 64 GPUs | 128 GPUs | 256 GPUs | 512 GPUs | 1024 GPUs |
|
||||
| ---------------------- | ------ | ------- | ------- | ------- | -------- | -------- | -------- | --------- |
|
||||
| (tp=1, zero1=-1) | 4062 | 3842 | 3752 | 3690 | 3571 | 3209 | 2861 | 2271 |
|
||||
| (tp=1, zero1=8) | 4078 | 3939 | 3919 | 3944 | 3928 | 3920 | 3835 | 3625 |
|
||||
| (tp=2, zero1=-1) | 3822 | 3595 | 3475 | 3438 | 3308 | 3094 | 2992 | 2785 |
|
||||
| (tp=2, zero1=4) | 3761 | 3658 | 3655 | 3650 | 3651 | 3653 | 3589 | 3486 |
|
||||
|
||||
<div align="left">
|
||||
<img src="../imgs/train_performance.png" width="580"/>
|
||||
</div>
|
||||
|
||||
|
||||
### FLOPS Testing
|
||||
|
||||
The computational workload of model training is based on the FLOPS calculation method described in the [Megatron](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) paper. To ensure constant FLOPS during training, the test configuration had `pack_sample_into_one=True`. The training used the following configuration:
|
||||
|
||||
Activation Checkpointing | tp | zero-1 | seq_len | micro_num | micro_bsz |
|
||||
| --- | --- | ------ | ------- | --------- | --------- |
|
||||
Disabled | 1 | 8 | 2048 | 4 | 2 |
|
||||
Enabled | 1 | 8 | 2048 | 1 | 8 |
|
||||
|
||||
The test results are shown in the table below. InternLM can achieve `>180 TFLOPS` for 7B model on thousand-card scale.
|
||||
|
||||
| Activation Checkpoint | 8 GPUs | 16 GPUs | 32 GPUs | 64 GPUs | 128 GPUs | 256 GPUs | 512 GPUs | 1024 GPUs |
|
||||
| --------------------- | ------ | ------- | ------- | ------- | -------- | -------- | -------- | --------- |
|
||||
| Disabled | 183 | 177 | 176 | 174 | 173 | 173 | 173 | 160 |
|
||||
| Enabled | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 |
|
||||
|
||||
<div align="left">
|
||||
<img src="../imgs/flops.png" width="580"/>
|
||||
</div>
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
## Pre-training and Fine-tuning Tutorial for InternLM
|
||||
|
||||
To start a demo model training, you need to prepare three things: **installation**, **dataset preparation**, and **model training configuration**. In this guide, we will first cover the steps for dataset preparation and then briefly describe the model training configuration.
|
||||
|
||||
### Installation
|
||||
|
||||
Please refer to the [installation guide](./install.md) for instructions on how to install the necessary dependencies.
|
||||
|
||||
### Dataset Preparation (Pre-training)
|
||||
|
||||
The dataset for InternLM training consists of a series of `bin` and `meta` files. To generate the training dataset, you need to use the `tokenizer` tool to tokenize the raw text data. The tokenizer model can be imported by specifying the model path in the `tools/tokenizer.py` script. The current provided model is `V7.model`. If you want to use a different model, you can modify the model path directly in the `tokenizer.py` script.
|
||||
|
||||
You can generate the `bin` and `meta` files for your raw data by running the following command, where the `raw_data_name` parameter represents the name of your raw data file, `input_file_type` represents the format of your raw data file (currently supports `txt`, `json`, and `jsonl`), and `bin` represents the path to save the generated `bin` files.
|
||||
|
||||
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'text' or 'json' or 'jsonl' --bin your_output_bin_path
|
||||
```
|
||||
|
||||
Here is an example of data processing (only the data processing example for the `txt` format is provided here, the data processing process for `json` and `jsonl` is exactly the same as for `txt`):
|
||||
|
||||
Given a file `raw_data.txt` containing the raw dataset, the raw dataset is shown below:
|
||||
|
||||
```bash
|
||||
Appreciate every detail in life to truly taste the flavor of happiness.
|
||||
Dreams are the source of life’s motivation. Pursue them diligently to achieve your goals.
|
||||
Learn to be tolerant and understanding to establish truly harmonious interpersonal relationships.
|
||||
```
|
||||
|
||||
You can generate the `bin` and `meta` files by running the following command:
|
||||
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
|
||||
```
|
||||
|
||||
It should be noted that the generated `bin` files need to be saved in one of the following directories: `cn`, `en`, `code`, `ja`, `ar`, or `kaoshi`, depending on the type of dataset.
|
||||
|
||||
Here, `cn` represents the Chinese dataset, `en` represents the English dataset, `code` represents the code dataset, `ja` represents the Japanese dataset, `ar` represents the Arabic dataset, and `kaoshi` represents the exam dataset.
|
||||
|
||||
The format of the generated `bin` files is as follows:
|
||||
|
||||
```python
|
||||
{"tokens": [98655, 2317, 2922, 6649, 1595, 7856, 435, 2424, 442, 9556, 12807, 410, 17313, 446, 23331, 95746]}
|
||||
{"tokens": [98655, 302, 1383, 269, 657, 410, 2687, 446, 2424, 98667, 269, 25220, 281, 523, 1874, 492, 1248, 38127, 4563, 442, 11227, 829, 8980, 95746]}
|
||||
{"tokens": [98655, 24190, 442, 517, 15013, 649, 454, 8793, 442, 5849, 9556, 17917, 1369, 1084, 29890, 12021, 95746]}
|
||||
```
|
||||
Each line in the `bin` file corresponds to each sentence in the original dataset, representing the tokens of each sentence (referred to as sequence below).
|
||||
|
||||
The format of the generated `meta` file is as follows:
|
||||
|
||||
```bash
|
||||
(0, 16), (110, 24), (262, 17)
|
||||
```
|
||||
|
||||
Each tuple in the `meta` file represents the meta information of each `sequence`, where the first element in the tuple indicates the `starting index` of each `sequence` among all `sequences`, and the second element indicates the number of `tokens` for each `sequence`.
|
||||
|
||||
For example, the first `sequence` starts at index 0 and has 16 `tokens`. The second `sequence` starts at index 110 and has 24 `tokens`.
|
||||
|
||||
The `bin` and `meta` file formats for `json` and `jsonl` type files are the same as for `txt`, so we won't go over them here.
|
||||
|
||||
### Data Preparation (Fine-tuning)
|
||||
|
||||
The data format for fine-tuning tasks is the same as for pre-training tasks, which consists of a series of `bin` and `meta` files. Let's take the Alpaca dataset as an example to explain the data preparation process for fine-tuning.
|
||||
|
||||
1. Download the [Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json).
|
||||
|
||||
2. Tokenize the Alpaca dataset using the following command:
|
||||
|
||||
```shell
|
||||
python tools/alpaca_tokenizer.py /path/to/alpaca_dataset /path/to/output_dataset /path/to/tokenizer --split_ratio 0.1
|
||||
```
|
||||
|
||||
It is recommended that users refer to alpaca_tokenizer.py to write new scripts to tokenize their own datasets
|
||||
|
||||
### Training Configuration
|
||||
|
||||
Taking the configuration file `configs/7B_sft.py` for the 7B demo as an example, let's discuss the data, model, and parallel configurations required to start a model training.
|
||||
|
||||
#### Data Configuration
|
||||
Here are the key parameters and their explanations for data configuration:
|
||||
|
||||
```python
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
SEQ_LEN = 2048
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN, # Length of the data samples, default value is 2048
|
||||
micro_num=1, # Number of micro_batches processed in one model parameter update, default value is 1
|
||||
micro_bsz=1, # Packed_length = micro_bsz * SEQ_LEN, the size of data processed in one micro_batch, default value is 1
|
||||
total_steps=50000, # Total number of steps to be executed, default value is 50000
|
||||
min_length=50, # If the number of lines in the dataset file is less than 50, it will be discarded
|
||||
train_folder=TRAIN_FOLDER, # Dataset file path, default value is None; if train_folder is empty, training will be done using randomly generated datasets
|
||||
pack_sample_into_one=False, # Logic for data arrangement, determines whether to calculate attention based on the seq_len dimension or the actual length of the sequence
|
||||
)
|
||||
```
|
||||
|
||||
<div align="left">
|
||||
<img src="../imgs/pack_into_one.png" width="550"/>
|
||||
</div>
|
||||
|
||||
|
||||
Currently, it supports passing the dataset file path `train_folder`, and the file format is required to be as follows:
|
||||
|
||||
```bash
|
||||
- folder
|
||||
- code
|
||||
train_000.bin
|
||||
train_000.bin.meta
|
||||
```
|
||||
|
||||
For detailed information about the dataset, please refer to the "Data Preparation" section.
|
||||
|
||||
#### Model Configuration
|
||||
|
||||
If you want to load a model checkpoint when starting the training, you can configure it as follows:
|
||||
|
||||
```python
|
||||
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
|
||||
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
|
||||
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
|
||||
ckpt = dict(
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints
|
||||
checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf
|
||||
load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights, only load model weights without loading optimizer weights, training will start from the first step
|
||||
load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the weights of the model and optimizer for resuming training, training will resume from the specified step
|
||||
load_optimizer=True, # Whether to load optimizer weights when resuming training, default value is True
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
|
||||
- If the path starts with `local:`, it means the file is stored in the local file system. If it starts with `boto3:`, it means the file is stored in the remote OSS.
|
||||
|
||||
The configuration for the model is as follows:
|
||||
|
||||
```python
|
||||
model_type = "INTERNLM" # Model type, default value is "INTERNLM", corresponding to the model structure initialization interface function
|
||||
NUM_ATTENTION_HEAD = 32
|
||||
VOCAB_SIZE = 103168
|
||||
HIDDEN_SIZE = 4096
|
||||
NUM_LAYER = 32
|
||||
MLP_RATIO = 8 / 3
|
||||
model = dict(
|
||||
checkpoint=False,
|
||||
num_attention_heads=NUM_ATTENTION_HEAD,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.bfloat16",
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
)
|
||||
```
|
||||
|
||||
Note: Users can customize the model type name and model structure, and configure the corresponding model parameters. The model initialization function interface can be registered through the `MODEL_INITIALIZER` object in `utils/registry.py`. When initializing the model in the training main function `train.py`, the specified model initialization interface function can be obtained through the `model_type` configuration.
|
||||
|
||||
#### Parallel Configuration
|
||||
|
||||
Training parallel configuration example:
|
||||
|
||||
```python
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
pipeline=1,
|
||||
tensor=1,
|
||||
)
|
||||
```
|
||||
|
||||
- zero1: zero parallel strategy, divided into the following three cases, default value is -1
|
||||
- When `size <= 0`, the size of the zero1 process group is equal to the size of the data parallel process group, so the optimizer state parameters will be split within the data parallel range.
|
||||
- When `size == 1`, zero1 is not used, and all data parallel groups retain the complete optimizer state parameters.
|
||||
- When `size > 1` and `size <= data_parallel_world_size`, the zero1 process group is a subset of the data parallel process group.
|
||||
- pipeline: pipeline parallel size, currently only supports 1, default value is 1
|
||||
- tensor: tensor parallel size, usually the number of GPUs per node, default value is 1
|
||||
|
||||
Note: `Data parallel size = Total number of GPUs / Pipeline parallel size / Tensor parallel size`
|
||||
|
||||
### Start Training
|
||||
|
||||
After completing the data preparation and relevant training configurations mentioned above, you can start the demo training. The following examples demonstrate how to start the training in both slurm and torch environments.
|
||||
|
||||
If you want to start distributed training on slurm with 16 GPUs across multiple nodes, use the following command:
|
||||
|
||||
```bash
|
||||
$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
|
||||
```
|
||||
|
||||
If you want to start distributed training on torch with 8 GPUs on a single node, use the following command:
|
||||
|
||||
```bash
|
||||
$ torchrun --nnodes=1 --nproc-per-node=8 train.py --config ./configs/7B_sft.py
|
||||
```
|
||||
|
||||
### Training Results
|
||||
|
||||
Taking the configuration of the demo training on a single machine with 8 GPUs on slurm as an example, the training result log is shown below:
|
||||
|
||||
```bash
|
||||
2023-07-04 21:40:14,148 INFO train.py:318 in record_current_batch_training_metrics -- step=17,loss=9.810295104980469,tgs (tokens per gpu per second)=4399.93,lr=3.8e-06,loss_scale=65536.0,grad_norm=4.177205427229359,micro_num=4,num_consumed_tokens=2359296,inf_nan_skip_batches=0,num_samples_in_batch=60,largest_length=1300,largest_batch=18,smallest_batch=13,adam_beta2=0.95,fwd_bwd_time=3.57
|
||||
2023-07-04 21:40:17,825 INFO train.py:318 in record_current_batch_training_metrics -- step=18,loss=9.715232849121094,tgs (tokens per gpu per second)=4457.7,lr=4.000000000000001e-06,loss_scale=65536.0,grad_norm=5.018154183978863,micro_num=4,num_consumed_tokens=2490368,inf_nan_skip_batches=0,num_samples_in_batch=68,largest_length=1153,largest_batch=19,smallest_batch=16,adam_beta2=0.95,fwd_bwd_time=3.52
|
||||
2023-07-04 21:40:21,526 INFO train.py:318 in record_current_batch_training_metrics -- step=19,loss=9.76744556427002,tgs (tokens per gpu per second)=4429.13,lr=4.2000000000000004e-06,loss_scale=65536.0,grad_norm=5.245329823265071,micro_num=4,num_consumed_tokens=2621440,inf_nan_skip_batches=0,num_samples_in_batch=70,largest_length=706,largest_batch=18,smallest_batch=17,adam_beta2=0.95,fwd_bwd_time=3.54
|
||||
2023-07-04 21:40:25,227 INFO train.py:318 in record_current_batch_training_metrics -- step=20,loss=9.628969192504883,tgs (tokens per gpu per second)=4427.46,lr=4.4e-06,loss_scale=65536.0,grad_norm=5.503176552110271,micro_num=4,num_consumed_tokens=2752512,inf_nan_skip_batches=0,num_samples_in_batch=69,largest_length=915,largest_batch=20,smallest_batch=15,adam_beta2=0.95,fwd_bwd_time=3.55
|
||||
2023-07-04 21:40:28,899 INFO train.py:318 in record_current_batch_training_metrics -- step=21,loss=9.690847396850586,tgs (tokens per gpu per second)=4464.18,lr=4.6e-06,loss_scale=65536.0,grad_norm=5.5336643273197526,micro_num=4,num_consumed_tokens=2883584,inf_nan_skip_batches=0,num_samples_in_batch=66,largest_length=870,largest_batch=17,smallest_batch=16,adam_beta2=0.95,fwd_bwd_time=3.52
|
||||
2023-07-04 21:40:32,629 INFO train.py:318 in record_current_batch_training_metrics -- step=22,loss=9.61986255645752,tgs (tokens per gpu per second)=4393.28,lr=4.800000000000001e-06,loss_scale=65536.0,grad_norm=9.01168869536059,micro_num=4,num_consumed_tokens=3014656,inf_nan_skip_batches=0,num_samples_in_batch=65,largest_length=1151,largest_batch=20,smallest_batch=14,adam_beta2=0.95,fwd_bwd_time=3.57
|
||||
```
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="159" height="20" role="img" aria-label="OpenCompass: Support"><title>OpenCompass: Support</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="159" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="106" height="20" fill="#555"/><rect x="106" width="53" height="20" fill="royalblue"/><rect width="159" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><image x="5" y="3" width="14" height="14" xlink:href=""/><text aria-hidden="true" x="625" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="790">OpenCompass</text><text x="625" y="140" transform="scale(.1)" fill="#fff" textLength="790">OpenCompass</text><text aria-hidden="true" x="1315" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="430">Support</text><text x="1315" y="140" transform="scale(.1)" fill="#fff" textLength="430">Support</text></g></svg>
|
After Width: | Height: | Size: 4.3 KiB |
After Width: | Height: | Size: 198 KiB |
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="120" height="20" role="img" aria-label="license: Apache-2.0"><title>license: Apache-2.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="120" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="47" height="20" fill="#555"/><rect x="47" width="73" height="20" fill="#97ca00"/><rect width="120" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="245" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="370">license</text><text x="245" y="140" transform="scale(.1)" fill="#fff" textLength="370">license</text><text aria-hidden="true" x="825" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="630">Apache-2.0</text><text x="825" y="140" transform="scale(.1)" fill="#fff" textLength="630">Apache-2.0</text></g></svg>
|
After Width: | Height: | Size: 1.1 KiB |
|
@ -0,0 +1,23 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="101" height="32" viewBox="0 0 101 32" fill="none">
|
||||
<path d="M93.8527 25.6137H93.3116C93.2804 25.6137 93.2544 25.6085 93.2232 25.6085H84.9801V20.2528H92.4793C92.5157 20.2528 92.5469 20.2476 92.5807 20.2476H93.0463C93.7616 20.2476 94.3391 19.6675 94.3391 18.9548V18.8429C94.3391 18.1302 93.759 17.5502 93.0463 17.5502H89.5374C89.5374 17.5502 89.5165 17.5528 89.5061 17.5528H84.9801V12.3349H92.4247C93.3247 12.3349 93.7747 11.9604 93.7747 11.2138V10.7092C93.7747 9.97569 93.3247 9.61154 92.4247 9.61154H84.9801V8.71935C84.9801 8.12369 84.5379 7.82715 83.6561 7.82715H81.5882C80.6882 7.82715 80.2382 8.20171 80.2382 8.94824V9.61154H77.2052L77.4523 9.06269L77.4809 8.92483C77.4809 8.81819 77.4341 8.72715 77.3431 8.64911C77.2338 8.57368 77.122 8.51125 77.0127 8.46703C76.8463 8.40721 76.7084 8.36819 76.5991 8.35258C76.4509 8.32137 76.3312 8.30056 76.2402 8.28495L74.4766 8.03265C74.3465 8.00143 74.2191 7.98583 74.0916 7.98583C73.4673 7.98583 73.0433 8.26155 72.8222 8.81038L70.8922 13.4326C70.8193 13.6017 70.7829 13.7525 70.7829 13.8904C70.7829 14.1349 70.8662 14.3248 71.03 14.4626C71.1965 14.6161 71.4514 14.7306 71.8026 14.806L73.5116 15.0817C73.6572 15.1129 73.7951 15.1285 73.9251 15.1285C74.513 15.1285 74.937 14.8684 75.1945 14.3508L76.2142 12.3375H80.2382V17.5554H75.0463C75.0463 17.5554 75.0254 17.5528 75.015 17.5528H71.5061C70.7907 17.5528 70.2133 18.1328 70.2133 18.8456V18.9574C70.2133 19.6727 70.7933 20.2502 71.5061 20.2502H72.0575C72.0913 20.2502 72.1225 20.2554 72.1589 20.2554H80.2382V25.6111H71.8832C71.852 25.6111 71.826 25.6163 71.7948 25.6163H71.2694C70.5566 25.6163 69.9766 26.1964 69.9766 26.9091V27.0209C69.9766 27.7362 70.5566 28.3137 71.2694 28.3137H71.5477C71.6517 28.3267 71.7636 28.3345 71.8832 28.3345H93.2518C93.3715 28.3345 93.4833 28.3267 93.5874 28.3137H93.8553C94.568 28.3137 95.1481 27.7336 95.1481 27.0209V26.9091C95.1481 26.1938 94.568 25.6163 93.8553 25.6163L93.8527 25.6137Z" fill="#858599"/>
|
||||
<path d="M61.2671 7.07092L61.433 7.03808C62.5054 6.82573 63.5469 7.52295 63.7593 8.59536L64.5515 12.5963C64.7638 13.6687 64.0666 14.7102 62.9942 14.9225L62.8283 14.9553C61.7559 15.1677 60.7144 14.4705 60.5021 13.398L59.7098 9.39715C59.4975 8.32475 60.1947 7.28326 61.2671 7.07092Z" fill="#858599"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M60.2579 17.8281H63.4287L63.4261 17.8333C64.3261 17.8333 64.7761 18.2079 64.7761 18.9544L64.7475 26.7344C64.7475 27.0856 64.6279 27.3613 64.3886 27.559C64.1492 27.7567 63.8163 27.8555 63.3949 27.8555L59.6284 27.8685C59.6137 27.8696 59.5993 27.8715 59.5851 27.8734C59.5642 27.8762 59.5434 27.8789 59.5218 27.8789H56.0778C55.3651 27.8789 54.7876 27.3015 54.7876 26.5887V26.4769C54.7876 25.7642 55.3651 25.1867 56.0778 25.1867H56.6033L56.6102 25.1861C56.6649 25.1811 56.7176 25.1763 56.7775 25.1763L58.2368 25.1867H59.5244C59.54 25.1867 59.5549 25.1893 59.5699 25.1919C59.5848 25.1945 59.5998 25.1971 59.6154 25.1971H59.6544C59.9848 25.1997 60.2527 24.9318 60.2527 24.6015V21.1732C60.2527 20.8428 59.9848 20.5775 59.657 20.5775H51.8327C51.5024 20.5775 51.237 20.8454 51.237 21.1758L51.2631 27.8555C51.2631 28.1884 51.1642 28.459 50.9639 28.6697C50.9457 28.6905 50.9223 28.7113 50.8989 28.7295C50.6648 28.909 50.3448 29 49.9417 29H48.0662C47.1662 29 46.7162 28.6254 46.7162 27.8789V21.1732C46.7162 20.8428 46.4483 20.5775 46.1205 20.5775H42.0549C41.9795 20.5775 41.9119 20.5723 41.8442 20.5671H41.298C40.5801 20.5671 40 19.9845 40 19.2692V19.1573C40 18.4394 40.5827 17.8593 41.298 17.8593H41.9535C41.9701 17.8593 41.986 17.8581 42.0023 17.8568C42.0192 17.8555 42.0364 17.8542 42.0549 17.8542H46.1205C46.4509 17.8542 46.7162 17.5862 46.7162 17.2585V12.3163C46.7162 11.986 46.4483 11.7207 46.1205 11.7207H42.0549C42.0432 11.7207 42.0322 11.72 42.0211 11.7194C42.0101 11.7187 41.999 11.7181 41.9873 11.7181H41.298C40.5801 11.7181 40 11.1354 40 10.4201V10.3083C40 9.59036 40.5827 9.01031 41.298 9.01031H41.8052C41.8858 9.00251 41.9691 8.99729 42.0575 8.99729H46.1232C46.4535 8.99729 46.7188 8.72418 46.7188 8.40164C46.7188 8.0791 46.8671 7.87361 47.161 7.78257C47.2885 7.73835 47.4732 7.71494 47.7125 7.71494H50.3605C50.9483 7.71494 51.2422 7.92823 51.2422 8.35482V8.40164C51.2422 8.72938 51.5076 8.99729 51.8353 8.99729L57.5553 9.01811C58.4553 9.01811 58.9053 9.38227 58.9053 10.1158V17.2559C58.9053 17.5862 59.1732 17.8516 59.501 17.8516H60.2579V17.8281ZM53.8616 17.8516C54.166 17.8516 54.4131 17.6044 54.4131 17.3001V12.2669C54.4131 11.9626 54.166 11.7155 53.8616 11.7155H51.7911C51.4868 11.7155 51.2396 11.9626 51.2396 12.2669V17.3001C51.2396 17.6044 51.4868 17.8516 51.7911 17.8516H53.8616Z" fill="#858599"/>
|
||||
<path d="M8.05859 27.1446C8.05859 27.1446 11.3184 26.7822 11.4996 26.7218C11.6808 26.6614 15.0615 26.8426 15.0615 26.8426L17.2342 27.3862L19.2257 28.111L19.7693 28.2922L22.7271 27.205L26.289 26.6614L28.7033 26.7218L30.5739 26.9634L31.3591 27.205L31.2987 30.4648L29.668 30.042L26.5289 29.8004L24.3562 29.9816L22.4855 30.4044L20.796 30.948L19.5881 31.4312L17.053 30.4648L13.9743 29.8608H10.4125L8.05859 30.2836V27.1446Z" fill="#858599" fill-opacity="0.5"/>
|
||||
<path d="M19.5787 31.9999C14.0908 28.9782 8.26229 30.7902 8.20357 30.8087L7.5459 31.0184V26.8576L7.89486 26.7452C8.1482 26.663 14.1847 24.7755 20.0636 28.0119L19.5787 28.8944C14.8609 26.2972 9.89306 27.272 8.55255 27.6075V29.6712C10.3041 29.2551 15.2367 28.4581 20.0636 31.1157L19.5787 31.9982V31.9999Z" fill="#858599"/>
|
||||
<path d="M19.5103 31.9999L19.0254 31.1174C23.8523 28.4599 28.7848 29.2568 30.5364 29.6729V27.6092C29.1976 27.272 24.2281 26.2989 19.5103 28.8961L19.0254 28.0136C24.9042 24.7772 30.9407 26.6647 31.1941 26.7469L31.543 26.8593V31.0218L30.8854 30.8087C30.8283 30.7902 24.9998 28.9783 19.5103 31.9999Z" fill="#858599"/>
|
||||
<path d="M25.8282 22.5107C25.8299 22.1466 25.7024 21.8346 25.444 21.5628C25.189 21.3044 24.8803 21.1819 24.5246 21.192C24.1606 21.2171 23.8502 21.3631 23.5884 21.6181C23.3334 21.9268 23.2093 22.249 23.211 22.6147C23.2093 22.9654 23.3334 23.2774 23.5834 23.5341C23.8384 23.8059 24.1505 23.9301 24.5196 23.9049C24.8786 23.9066 25.1873 23.7606 25.4407 23.4788C25.699 23.1969 25.8266 22.8731 25.8299 22.509L25.8282 22.5107Z" fill="#858599" fill-opacity="0.5"/>
|
||||
<path d="M15.6646 23.4805C15.923 23.1986 16.0505 22.8899 16.0488 22.5242C16.0505 22.1601 15.923 21.8212 15.6679 21.536C15.3945 21.3045 15.0891 21.1954 14.7469 21.2055C14.3711 21.2441 14.0607 21.3766 13.8107 21.6048C13.5557 21.8866 13.4282 22.2222 13.4248 22.6013C13.4248 22.9654 13.554 23.2775 13.809 23.5359C14.0556 23.8077 14.366 23.9301 14.7334 23.9066C15.1126 23.8949 15.4196 23.7624 15.6629 23.4805H15.6646Z" fill="#858599" fill-opacity="0.5"/>
|
||||
<path d="M19.4917 27.1262C18.4834 27.1262 17.4935 27.0423 16.7352 26.8477C14.695 26.3243 13.071 25.3797 11.7707 23.9586C10.4201 22.4805 10.3346 21.0561 10.244 19.5478C10.2322 19.3549 10.2205 19.1586 10.2054 18.959C10.0846 17.2896 9.29605 14.6807 9.28766 14.6555L9.22559 14.4509L9.35142 14.278C10.5124 12.6976 11.9435 11.5685 13.7303 10.8236C15.3712 10.139 17.3241 9.78168 19.7014 9.72967C23.6592 9.49814 27.9744 11.144 29.9642 13.6455L30.09 13.8032L29.1891 18.006C29.179 18.2325 29.1119 19.4153 28.7864 21.8329C28.4039 24.665 24.9058 26.2857 23.1089 26.7454C22.1962 26.9803 20.8255 27.1262 19.49 27.1262H19.4917ZM10.1685 14.6136C10.3631 15.283 10.9503 17.3953 11.0594 18.8952C11.0745 19.0982 11.0862 19.2995 11.098 19.4958C11.1852 20.9655 11.2557 22.1265 12.4016 23.3815C13.586 24.6767 15.0725 25.5408 16.9466 26.0206C18.5354 26.4283 21.3993 26.3025 22.8975 25.9166C24.1877 25.5861 27.6087 24.1851 27.9408 21.7155C28.2915 19.1167 28.3368 17.9439 28.3385 17.9322V17.8953L29.1706 14.0197C27.862 11.5064 27.0852 10.9913 25.4209 10.3639C23.6659 9.70115 21.5268 10.4729 19.7434 10.5786H19.7283C15.1816 10.6776 11.8479 9.21963 10.1668 14.6119L10.1685 14.6136Z" fill="#858599"/>
|
||||
<path d="M28.4299 14.3451C28.7538 15.1789 28.8913 15.5967 29.195 16.3601C29.195 16.3601 29.9986 15.0699 30.252 14.1471C30.507 13.2143 30.351 12.0264 30.1513 10.9158C29.6497 9.05179 29.2604 8.9075 27.7555 7.80354C27.0877 7.34551 26.3328 6.9395 25.4788 6.58214C23.6316 5.88587 21.776 5.49999 19.9087 5.41946C18.0581 5.38758 16.1908 5.66441 14.3201 6.24659C13.451 6.54858 12.6759 6.91098 11.9897 7.33041C10.4261 8.34209 9.3372 9.7665 8.73153 11.5952C8.56879 12.2831 8.31042 12.9442 8.37417 13.5817C8.6661 16.506 10.5737 17.7895 10.4428 17.0848C10.2734 16.1755 10.3824 15.6386 10.6156 14.7695C11.0586 13.1254 12.1072 11.9107 13.7631 11.1372H13.7681C14.0903 11.0466 15.7227 10.1188 17.8903 11.2278C18.4658 11.4745 19.4725 12.1422 19.9758 12.1271C20.4808 12.112 21.2072 11.6254 21.7357 11.2949C22.729 10.6725 23.6702 10.4594 24.9134 10.8537C26.7203 11.4224 27.8293 12.7076 28.4299 14.3468V14.3451Z" fill="#858599"/>
|
||||
<path d="M1.77058 12.8872L3.12955 13.2865L4.68818 13.206L5.80723 12.4459L7.00683 11.3672L7.88596 10.1676L9.00502 9.00824L10.0838 8.32875L11.5619 8.0083L10.562 9.76658L9.7231 10.6055L7.96482 13.0432L7.16621 14.2428L6.2468 14.882L4.9281 15.7612C4.9281 15.7612 3.76878 16.0011 3.64966 16.0414C3.53054 16.0816 2.57087 16.241 2.57087 16.241L1.73199 16.2007L1.09277 16.0011L1.69173 14.8015L1.77226 13.9224V12.8838L1.77058 12.8872Z" fill="#858599" fill-opacity="0.5"/>
|
||||
<path d="M5.24702 5.97485L5.12621 6.89258L4.80745 7.53348L4.40814 8.13243L3.4082 8.73139L4.6078 9.33035L5.9265 9.85045L7.28547 10.1306L7.88443 10.1709L9.32394 9.17096L10.8423 7.81199L10.3222 7.6526L8.4448 7.45295L7.04555 6.93285L6.04562 6.41442L5.24702 5.97485Z" fill="#858599" fill-opacity="0.5"/>
|
||||
<path d="M2.21392 16.5798C1.79113 16.5798 1.34653 16.5547 0.880114 16.506L0.000976562 16.4053L0.433838 16.1268L0.665362 15.8769C0.697239 15.8282 0.78952 15.6671 0.942194 15.3282V15.3249C1.18882 14.7947 1.27439 13.9072 1.19889 12.6875L1.14688 11.867L1.77939 12.3905C2.44042 12.9374 3.26923 13.0683 4.3111 12.7864C5.24225 12.5214 6.25561 11.6523 7.3193 10.2061C8.29743 8.79844 9.70506 7.8438 11.507 7.36732L11.6764 7.32202L12.8156 8.20955L11.7922 8.13573C11.6965 8.22801 11.6009 8.32699 11.5137 8.43605L11.4935 8.46121L11.4667 8.4847C10.8795 9.52826 10.0389 10.5584 9.19838 11.6858L9.18663 11.706C9.05745 11.9325 8.92491 12.1674 8.78062 12.4006V12.4123L8.73364 12.4878C7.87296 13.887 6.92671 14.9356 5.92007 15.6067C4.99563 16.2543 3.75242 16.5815 2.21392 16.5815V16.5798ZM1.48578 15.845C3.22225 15.9574 4.57619 15.6822 5.51572 15.0212L5.52243 15.0162C6.43345 14.4105 7.30085 13.4492 8.10114 12.1606C8.11624 12.117 8.13804 12.0801 8.16153 12.0516C8.30414 11.8217 8.43836 11.5852 8.56922 11.3553V11.352C9.13798 10.2614 9.7923 9.2917 10.5221 8.46121C9.43327 8.94105 8.5709 9.65409 7.90315 10.6154L7.89813 10.6238C6.72203 12.2261 5.61136 13.1572 4.50405 13.4726C3.52592 13.736 2.66691 13.6891 1.93877 13.3334C1.95387 14.3367 1.83978 15.09 1.59316 15.6218C1.55624 15.7057 1.52101 15.7795 1.48914 15.845H1.48578Z" fill="#858599"/>
|
||||
<path d="M2.29475 16.5799C1.87196 16.5799 1.42735 16.5547 0.960938 16.5061L1.03476 15.7997C3.02792 16.0094 4.56138 15.7477 5.59487 15.0213L5.60326 15.0162C6.51428 14.4106 7.38167 13.4492 8.18196 12.1607C8.21048 12.0818 8.25577 12.0282 8.29604 11.9929L8.76078 12.5298C8.83124 12.4677 8.86145 12.3838 8.86648 12.3251L8.85976 12.4107L8.81447 12.4862C7.95379 13.8854 7.00753 14.934 6.00089 15.6051C5.07645 16.2527 3.83324 16.5799 2.29475 16.5799Z" fill="#858599"/>
|
||||
<path d="M8.36647 9.76984C8.21547 9.77991 8.06447 9.7883 7.91347 9.79333C6.71724 9.83528 5.44048 9.5014 4.11003 8.80178C5.26264 8.05686 5.58309 7.20792 5.53779 6.44287C5.81294 6.66098 6.09983 6.85224 6.39847 7.01498C8.35472 8.11726 10.4217 8.29175 13.955 7.30356L13.7151 6.63413C10.358 7.56025 8.53927 7.40757 6.74241 6.39422C6.23237 6.11739 5.75925 5.74493 5.33646 5.29026L5.28781 5.23825L4.55966 4.44971L4.69724 5.46307C4.71737 5.54696 5.13513 7.52502 3.22754 8.48972L2.65039 8.78165L3.2074 9.11217C4.76771 10.04 6.28103 10.5097 7.70711 10.5097L8.36647 9.77152V9.76984Z" fill="#858599"/>
|
||||
<path d="M17.3537 20.6702C17.1322 20.6702 16.9242 20.5326 16.8218 20.301C16.347 19.2172 15.9746 18.7659 15.4125 18.0831C14.8437 17.3935 13.6324 16.2258 13.4328 16.0564C13.166 15.8299 13.0905 15.5211 13.2499 15.3081C13.3304 15.2024 13.5687 14.9893 14.0586 15.3148C15.0518 15.9725 17.2564 17.9505 17.852 19.5393C18.0701 20.1165 17.8519 20.5091 17.56 20.6282C17.4912 20.6567 17.4224 20.6685 17.352 20.6685L17.3537 20.6702Z" fill="#858599"/>
|
||||
<path d="M21.6405 20.6971C21.57 20.6971 21.4979 20.6836 21.4274 20.6534C21.1355 20.531 20.9241 20.135 21.1489 19.5612C21.7646 17.9808 23.9944 16.0312 24.996 15.3853C25.4909 15.0665 25.7258 15.2813 25.8046 15.3887C25.9624 15.6034 25.8835 15.9104 25.6134 16.1336C25.4121 16.3014 24.1856 17.4506 23.6085 18.1351C23.038 18.8113 22.6606 19.2575 22.1707 20.3363C22.0666 20.5645 21.8603 20.6987 21.6422 20.6987L21.6405 20.6971Z" fill="#858599"/>
|
||||
<path d="M23.8298 3.36916C23.8449 2.54539 23.4439 1.8357 22.6269 1.24178C21.7964 0.6596 20.8049 0.365991 19.6522 0.357603C18.4996 0.349214 17.5114 0.632753 16.6977 1.21158C15.879 1.79543 15.4612 2.49673 15.4512 3.3138C15.4545 3.56882 15.4763 3.80202 15.5082 4.02013C18.5634 3.03697 21.2998 3.1175 23.7124 4.2634C23.7778 3.95805 23.8181 3.66109 23.8298 3.36748V3.36916Z" fill="#858599" fill-opacity="0.5"/>
|
||||
<path d="M11.0557 19.9505L11.0406 19.9153C11.0003 19.823 10.9651 19.7291 10.9366 19.6368L10.851 19.3616L10.2856 19.3801L10.242 19.8549C10.2453 19.8733 10.2554 19.9186 10.2655 19.9992V20.016C10.2705 20.0445 10.2722 20.068 10.2772 20.0915C10.2856 20.3431 10.2101 20.5428 10.039 20.7189C9.87623 20.8917 9.66315 20.9689 9.38297 20.9572C9.03568 20.9437 8.91992 20.8045 8.86287 20.7005C8.7102 20.4253 8.70013 20.2777 8.70852 20.2122C8.73033 20.0327 8.76724 19.8431 8.81925 19.6519C8.89978 19.3532 8.83939 19.0059 8.63638 18.6217L8.42833 18.2275C8.34948 18.0782 8.28069 17.8902 8.22365 17.6688C8.18003 17.4826 8.19178 17.2594 8.25721 17.0027C8.31593 16.7762 8.44176 16.5984 8.64141 16.4574C8.84106 16.3232 9.0793 16.2746 9.36284 16.3115L9.38465 16.3148H9.40646C9.68497 16.3165 9.91314 16.3434 10.0876 16.3903C10.039 16.1185 10.0239 15.835 10.049 15.5464C9.86617 15.4693 9.68162 15.4156 9.49539 15.3904H9.48532L9.34774 15.3787C8.93334 15.3401 8.529 15.4323 8.14647 15.6538C7.75052 15.8853 7.48377 16.2276 7.35458 16.6722C7.24385 17.0598 7.23043 17.4591 7.31264 17.8567C7.3227 17.9171 7.33612 17.9876 7.35961 18.0547C7.42168 18.251 7.52403 18.4691 7.66496 18.704C7.87971 19.058 7.94515 19.1737 7.9636 19.2123L7.97031 19.2257C7.99715 19.2744 8.00219 19.3482 7.98373 19.4455L7.97702 19.4774C7.95688 19.5864 7.93676 19.6888 7.92837 19.7274C7.76059 20.1619 7.74381 20.5629 7.87636 20.9135C8.01058 21.286 8.24714 21.5578 8.58269 21.7256C8.77731 21.8212 9.00213 21.8816 9.25211 21.9068C9.38801 21.9202 9.53061 21.9219 9.67993 21.9135C10.1648 21.8883 10.5591 21.6585 10.8191 21.2541C11.0741 20.8699 11.1564 20.4438 11.0674 19.9891L11.0607 19.9505H11.0557Z" fill="#858599"/>
|
||||
<path d="M31.8916 16.085C31.7607 15.6454 31.4889 15.3048 31.0829 15.0716H31.0813C30.7172 14.8636 30.328 14.778 29.9219 14.8166C29.9068 14.8166 29.8917 14.82 29.8766 14.8216L29.7424 14.8384H29.7307C29.4723 14.8804 29.2139 14.9777 28.9623 15.1236C28.9639 15.1438 28.9673 15.1656 28.9706 15.1857C28.9941 15.4424 28.9824 15.6958 28.9388 15.9424L29.0847 15.8837C29.271 15.8082 29.5243 15.7461 29.8582 15.6958L29.8766 15.6924C30.1317 15.6404 30.3514 15.6874 30.5662 15.8434C30.7927 16.0078 30.9454 16.2125 31.0326 16.4658C31.1047 16.7041 31.1148 16.9088 31.0645 17.09C31.0024 17.3165 30.9319 17.5178 30.8548 17.6872C30.7625 17.8819 30.6685 18.0714 30.5763 18.2443C30.437 18.5026 30.3867 18.7828 30.4286 19.0747C30.4638 19.3298 30.5008 19.5412 30.5427 19.7224C30.5729 19.8482 30.5544 20.0092 30.489 20.2005L30.4857 20.2089C30.432 20.38 30.1434 20.4723 29.9756 20.5126C29.5931 20.6032 29.422 20.4639 29.3196 20.3381C29.1451 20.1133 29.0629 19.87 29.068 19.5965L29.0747 19.5177C29.0831 19.4388 29.0914 19.3985 29.0948 19.3834L29.0831 18.8935L28.4556 18.9086L28.3885 19.1771C28.365 19.2694 28.3348 19.36 28.2962 19.4506L28.2794 19.4891L28.2727 19.5311C28.1989 19.9774 28.2912 20.4018 28.5445 20.7944L28.5512 20.8028C28.828 21.2021 29.224 21.4219 29.6954 21.4387C29.8213 21.4454 29.9437 21.442 30.0595 21.432C30.338 21.4068 30.5863 21.333 30.7994 21.2139C31.1131 21.0411 31.3363 20.7642 31.4671 20.3901C31.5913 20.0227 31.5661 19.6234 31.3933 19.2023L31.3514 18.9539L31.348 18.9355C31.3279 18.8466 31.3363 18.7593 31.3698 18.6721C31.4353 18.516 31.5242 18.3349 31.6332 18.1285C31.7758 17.8768 31.8681 17.6638 31.9184 17.4826C31.9436 17.402 31.9604 17.3316 31.9654 17.2611C32.0309 16.8668 32.0057 16.4726 31.8882 16.0884L31.8916 16.085Z" fill="#858599"/>
|
||||
<path d="M24.1318 4.14939C24.1804 3.88095 24.2089 3.62761 24.219 3.37427C24.2358 2.43137 23.7811 1.61599 22.8651 0.94992C21.9809 0.329153 20.9122 0.010386 19.6891 0.000319559C18.4627 -0.0114247 17.399 0.300635 16.5249 0.921401C15.6105 1.57404 15.1407 2.37768 15.129 3.31051V3.3189C15.1307 3.50513 15.1424 3.69136 15.1659 3.88598C15.0803 3.98832 15.0334 4.10073 15.0283 4.22153V4.23495C15.0283 4.55037 15.0736 4.86578 15.1692 5.17113L15.1709 5.17952C15.3018 5.56541 15.5266 5.93283 15.8403 6.27173L15.9309 6.36905L16.7564 6.46132C17.461 6.87069 18.4023 7.09048 19.5599 7.11061C19.6102 7.11061 19.6606 7.11061 19.7092 7.11061C20.5632 7.11061 21.3098 7.00659 21.9306 6.8019L21.9473 6.7952C22.1805 6.7046 22.407 6.59554 22.6235 6.46971C23.2812 6.53179 23.7744 6.20798 24.0277 5.54863L24.0898 5.48152L24.0865 5.38756C24.1838 5.08892 24.3549 4.50339 24.1301 4.14939H24.1318ZM16.9359 1.50022C17.6842 0.968375 18.6069 0.703294 19.6824 0.711683C20.7578 0.720071 21.6906 0.996899 22.4507 1.52874C23.1738 2.05388 23.5194 2.65451 23.5076 3.35413C23.5026 3.48332 23.4925 3.61586 23.4741 3.75344C22.2527 3.24173 20.9491 2.98671 19.5683 2.98671C18.3821 2.98671 17.1406 3.17629 15.8487 3.55379C15.8437 3.47326 15.8404 3.39272 15.8387 3.31387C15.8487 2.61257 16.2078 2.01864 16.9359 1.50022ZM23.5009 4.83558L23.261 5.35401C23.261 5.35401 22.9473 5.807 22.605 5.74996L22.4725 5.72815L22.3584 5.79861C22.1504 5.9278 21.9289 6.03853 21.699 6.12745C21.1202 6.31704 20.4038 6.40764 19.5717 6.39589C18.5415 6.37576 17.7161 6.1912 17.1171 5.84391C17.0987 5.83217 17.0802 5.82042 17.0634 5.81203L16.993 5.76841L16.2782 5.68956C16.1423 5.53017 15.9494 5.19462 15.9494 5.19462L15.7531 4.62586C15.7531 4.62586 15.743 4.38427 15.7414 4.32891C18.6338 3.41789 21.2544 3.49339 23.5311 4.55372C23.5345 4.59567 23.5009 4.83558 23.5009 4.83558ZM23.714 5.36407V5.35904L23.719 5.36407H23.714Z" fill="#858599"/>
|
||||
</svg>
|
After Width: | Height: | Size: 18 KiB |
After Width: | Height: | Size: 279 KiB |
After Width: | Height: | Size: 4.2 KiB |
After Width: | Height: | Size: 391 KiB |
After Width: | Height: | Size: 3.7 KiB |
|
@ -0,0 +1,58 @@
|
|||
## InternLM项目的依赖安装
|
||||
|
||||
### 环境准备
|
||||
首先,需要安装的依赖包及对应版本列表如下:
|
||||
- Python == 3.10
|
||||
- GCC == 10.2.0
|
||||
- MPFR == 4.1.0
|
||||
- CUDA == 11.7
|
||||
- Pytorch == 1.13.1+cu117
|
||||
- Transformers >= 4.25.1
|
||||
- Flash-Attention == 23.05
|
||||
- Ampere或者Hopper架构的GPU (例如H100, A100)
|
||||
- Linux OS
|
||||
|
||||
以上依赖包安装完成后,需要更新配置系统环境变量:
|
||||
```bash
|
||||
export CUDA_PATH={path_of_cuda_11.7}
|
||||
export GCC_HOME={path_of_gcc_10.2.0}
|
||||
export MPFR_HOME={path_of_mpfr_4.1.0}
|
||||
export LD_LIBRARY_PATH=${GCC_HOME}/lib64:${MPFR_HOME}/lib:${CUDA_PATH}/lib64:$LD_LIBRARY_PATH
|
||||
export PATH=${GCC_HOME}/bin:${CUDA_PATH}/bin:$PATH
|
||||
export CC=${GCC_HOME}/bin/gcc
|
||||
export CXX=${GCC_HOME}/bin/c++
|
||||
```
|
||||
|
||||
### 环境安装
|
||||
将项目`internlm`及其依赖子模块,从 github 仓库中 clone 下来,命令如下:
|
||||
```bash
|
||||
git clone git@github.com:InternLM/InternLM.git --recurse-submodules
|
||||
```
|
||||
|
||||
推荐使用 conda 构建一个 Python-3.10 的虚拟环境, 并基于`requirements/`文件安装项目所需的依赖包:
|
||||
```bash
|
||||
conda create --name internlm-env python=3.10 -y
|
||||
conda activate internlm-env
|
||||
cd internlm
|
||||
pip install -r requirements/torch.txt
|
||||
pip install -r requirements/runtime.txt
|
||||
```
|
||||
|
||||
安装 flash-attention (version v1.0.5):
|
||||
```bash
|
||||
cd ./third_party/flash-attention
|
||||
python setup.py install
|
||||
cd ./csrc
|
||||
cd fused_dense_lib && pip install -v .
|
||||
cd ../xentropy && pip install -v .
|
||||
cd ../rotary && pip install -v .
|
||||
cd ../layer_norm && pip install -v .
|
||||
cd ../../../../
|
||||
```
|
||||
|
||||
安装 Apex (version 23.05):
|
||||
```bash
|
||||
cd ./third_party/apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
||||
cd ../../
|
||||
```
|
|
@ -0,0 +1,25 @@
|
|||
## InternLM系统结构
|
||||
本项目系统代码文件结构如下所示:
|
||||
```bash
|
||||
├── configs # 配置模块,管理模型和训练相关参数
|
||||
│ └── 7B_sft.py # 7B_sft.py 是系统 demo 的配置文件样例
|
||||
├── internlm # 系统代码的主目录
|
||||
│ ├── apis # 接口模块,包含一些关于推理等的接口函数
|
||||
│ ├── core # 核心模块,管理用于训练和推理的 parallel context 和训练调度引擎
|
||||
│ │ ├── context # context 模块,主要负责初始化并行进程组,并管理 parallel context
|
||||
│ │ │ ├── parallel_context.py
|
||||
│ │ │ └── process_group_initializer.py
|
||||
│ │ ├── engine.py # 负责管理模型的训练和评估过程
|
||||
│ │ ├── no_pipeline_scheduler.py # 并行训练的调度器
|
||||
│ │ └── trainer.py # 负责管理训练引擎和调度器
|
||||
│ ├── data # 数据模块,负责管理数据集生成和处理
|
||||
│ ├── initialize # 初始化模块,负责管理分布式环境启动和训练器初始化
|
||||
│ ├── model # 模型模块,负责管理模型结构定义和实现
|
||||
│ ├── solver # 负责管理 optimizer 和 lr_scheduler 等的实现
|
||||
│ └── utils # 辅助模块,负责管理日志、存储、模型注册等
|
||||
├── train.py # 模型训练的主函数入口文件
|
||||
├── requirements # 系统运行的依赖包列表
|
||||
├── third_party # 系统所依赖的第三方模块,包括 apex 和 flash-attention 等
|
||||
├── tools # 一些脚本工具,用于原始数据集处理和转换,模型 checkpoint 转换等
|
||||
└── version.txt # 系统版本号
|
||||
```
|
|
@ -0,0 +1,62 @@
|
|||
## 训练性能
|
||||
|
||||
InternLM 深度整合了 Flash-Attention, Apex 等高性能模型算子,提高了训练效率。通过构建 Hybrid Zero 技术,实现计算和通信的高效重叠,大幅降低了训练过程中的跨节点通信流量。InternLM 支持 7B 模型从 8 卡扩展到 1024 卡,千卡规模下加速效率可高达 90%,训练吞吐超过 180TFLOPS,平均单卡每秒处理的 token 数量超过3600。下表为 InternLM 在不同配置下的扩展性测试数据:
|
||||
|
||||
| InternLM | 8卡 | 16卡 | 32卡 | 64卡 | 128卡 | 256卡 | 512卡 | 1024卡 |
|
||||
| ---------------- | ---- | ---- | ---- | ---- | ----- | ----- | ----- | ------ |
|
||||
| TKS (Tokens/GPU/Second) | 4078 | 3939 | 3919 | 3944 | 3928 | 3920 | 3835 | 3625 |
|
||||
| TFLOPS | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 |
|
||||
|
||||
|
||||
我们在GPU集群上测试了多种并行配置下,InternLM训练7B模型的性能。在每组测试中,每张GPU在单次迭代中处理的token数量一致。测试使用的硬件和参数配置如下表所示:
|
||||
|
||||
| 硬件 | 硬件型号 |
|
||||
| ----------------------- | ----------------------------- |
|
||||
| GPU | nvidia_a100-sxm4-80gb |
|
||||
| Memory | 2TB |
|
||||
| Inter-machine bandwidth | 4 * 100Gb RoCE |
|
||||
| CPU | 128 core Intel(R) Xeon(R) CPU |
|
||||
|
||||
| 超参 | tp=1 | tp=2 |
|
||||
| --------- | ---- | ---- |
|
||||
| micro_num | 4 | 4 |
|
||||
| micro_bsz | 2 | 4 |
|
||||
| seq_len | 2048 | 2048 |
|
||||
|
||||
InternLM中`zero1`的配置决定了优化器状态的分配范围。
|
||||
- `zero1=-1`表明优化器状态分布在全部数据并行节点(等同于Deepspeed Zero-1的效果)
|
||||
- `zero1=8,tp=1`的情况下,优化器状态分布在单节点8张GPU内,并且不同节点上的优化器状态保持一致。
|
||||
|
||||
### 吞吐量测量
|
||||
|
||||
吞吐量定义为TGS,平均每GPU每秒处理的token的数量(Tokens per GPU per Second)。在该项测试的训练配置中,`pack_sample_into_one=False`,`checkpoint=False`。测试结果如下表所示。采用`zero1=8,tp=1`,InternLM针对7B模型训练的扩展性,在千卡训练的加速效率可以达到`88%`。
|
||||
|
||||
| 并行配置 | 8卡 | 16卡 | 32卡 | 64卡 | 128卡 | 256卡 | 512卡 | 1024卡 |
|
||||
| ---------------- | ---- | ---- | ---- | ---- | ----- | ----- | ----- | ------ |
|
||||
| (tp=1, zero1=-1) | 4062 | 3842 | 3752 | 3690 | 3571 | 3209 | 2861 | 2271 |
|
||||
| (tp=1, zero1=8) | 4078 | 3939 | 3919 | 3944 | 3928 | 3920 | 3835 | 3625 |
|
||||
| (tp=2, zero1=-1) | 3822 | 3595 | 3475 | 3438 | 3308 | 3094 | 2992 | 2785 |
|
||||
| (tp=2, zero1=4) | 3761 | 3658 | 3655 | 3650 | 3651 | 3653 | 3589 | 3486 |
|
||||
|
||||
|
||||
<div align="left">
|
||||
<img src="../doc/imgs/train_performance.png" width="580"/>
|
||||
</div>
|
||||
|
||||
### FLOPS测试
|
||||
模型训练的计算量参考 [Megatron](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) 论文中FLOPS计算方式。为了保证训练过程中的FLOPS恒定,在该项测试的训练配置中,`pack_sample_into_one=True`,其余超参设置如下所示:
|
||||
|
||||
activation checkpoint | tp | zero-1 | seq_len | micro_num | micro_bsz |
|
||||
| --- | --- | ---- | ---- | ---- |---- |
|
||||
关闭 | 1 | 8 | 2048 | 4 | 2 |
|
||||
开启 | 1 | 8 | 2048 | 1 | 8 |
|
||||
|
||||
测试结果如下表所示,InternLM针对7B模型的千卡训练,可以达到 `>180 TFLOPS`:
|
||||
| activation checkpoint | 8卡 | 16卡 | 32卡 | 64卡 | 128卡 | 256卡 | 512卡 | 1024卡 |
|
||||
| --------------- | --- | ---- | ---- | ---- | ----- | ----- | ----- | ------ |
|
||||
| 关闭 | 183 | 177 | 176 | 174 | 173 | 173 | 173 | 160 |
|
||||
| 开启 | 192 | 192 | 186 | 186 | 185 | 185 | 186 | 182 |
|
||||
|
||||
<div align="left">
|
||||
<img src="../doc/imgs/flops.png" width="580"/>
|
||||
</div>
|
|
@ -0,0 +1,191 @@
|
|||
## 基于InternLM的预训练与微调使用教程
|
||||
|
||||
启动一个 Demo 模型训练,需要进行三项准备,**安装**,**数据集准备**和**模型训练配置**。接下来,首先会介绍数据准备相关的操作,再简要描述模型训练配置相关的内容。
|
||||
|
||||
### 安装
|
||||
请参考[安装文档](./install.md)进行安装。
|
||||
|
||||
### 数据准备 (预训练)
|
||||
|
||||
InternLM训练任务的数据集包括一系列的`bin`和`meta`文件。使用`tokenizer`从原始文本文件生成训练用数据集。通过在`tools/tokenizer.py`中指定模型参数路径的方式来导入tokenizer模型。目前提供`V7.model`来生成tokens。若想使用不同的模型,可直接修改`tokernizer.py`中的模型参数路径。
|
||||
|
||||
可以运行以下命令生成原始数据对应的`bin`和`meta`文件,其中参数`raw_data_name`表示原始数据集的文件名称,`input_file_type`表示原始数据集的文件格式,目前支持`txt`、`json`和`jsonl`这三种格式,`bin`表示生成的`bin`文件的保存路径。
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'text' or 'json' or 'jsonl' --bin your_output_bin_path
|
||||
```
|
||||
|
||||
下面是一个数据处理的例子(这里只给出了`txt`格式的数据处理例子,`json`和`jsonl`的数据处理流程和`txt`的完全一致):
|
||||
|
||||
给定一个包含原始数据集的文件`raw_data.txt`,原始数据集如下所示:
|
||||
```bash
|
||||
感恩生活中的每一个细节,才能真正体会到幸福的滋味。
|
||||
梦想是人生的动力源泉,努力追逐,才能实现自己的目标。
|
||||
学会宽容和理解,才能建立真正和谐的人际关系。
|
||||
```
|
||||
|
||||
可以通过运行以下命令来生成`bin`和`meta`文件:
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
|
||||
```
|
||||
|
||||
需要注意的是,生成的`bin`文件需要保存在`cn`或者`en`或者`code`或者`ja`或者`ar`或者`kaoshi`这六个目录下,以区分数据集的类型。
|
||||
|
||||
其中,`cn`表示中文数据集;`en`表示英文数据集;`code`表示代码数据集;`ja`表示日语数据集;`ar`表示阿拉伯语数据集;`kaoshi`表示考试数据集。
|
||||
|
||||
生成的bin文件的格式如下:
|
||||
```python
|
||||
{"tokens": [73075, 75302, 69522, 69022, 98899, 67713, 68015, 81269, 74637, 75445, 99157]}
|
||||
{"tokens": [69469, 60355, 73026, 68524, 60846, 61844, 98899, 67775, 79241, 98899, 67713, 67800, 67453, 67838, 99157]}
|
||||
{"tokens": [68057, 79017, 60378, 68014, 98899, 67713, 67990, 68015, 70381, 67428, 61003, 67622, 99157]}
|
||||
```
|
||||
`bin`文件中的每一行均对应原始数据集中的每一个句子,表示每个句子的`token`(下文将用sequence指定)。
|
||||
|
||||
生成的`meta`文件的格式如下:
|
||||
```bash
|
||||
(0, 11), (90, 15), (208, 13)
|
||||
```
|
||||
在`meta`文件中,每个元组对应着`bin`文件中每一个`sequence`的元信息。其中,元组的第一个元素表示每个`sequence`在所有`sequence`中的`starting index`,第二个元素表示每个`sequence`中有多少个`tokens`。
|
||||
|
||||
例如,对于第一个`sequence`,`starting index`为 0,有 11 个`tokens`;对于第二个`sequence`,由于第一个`sequence`转换为`string`后的长度为`89`,因此它的`starting index`为 90,有 15 个`tokens`。
|
||||
|
||||
`json`和`jsonl`类型的文件的`bin`和`meta`文件格式和`txt`一致,此处不再赘叙。
|
||||
|
||||
### 数据准备 (微调)
|
||||
|
||||
微调任务的数据集格式与预训练任务保持一致,生成的数据格式为一系列的`bin`和`meta`文件。以下以 Alpaca 数据集为例,介绍微调的数据准备流程。
|
||||
|
||||
1. 下载 [Alpaca 数据集](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json)
|
||||
|
||||
2. 对 Alpaca 数据进行 tokenize,使用以下命令
|
||||
|
||||
```shell
|
||||
python tools/alpaca_tokenizer.py /path/to/alpaca_dataset /path/to/output_dataset /path/to/tokenizer --split_ratio 0.1
|
||||
```
|
||||
|
||||
建议用户参考 alpaca_tokenizer.py 编写新的脚本对自己的数据集进行 tokenize
|
||||
|
||||
### 训练配置
|
||||
|
||||
以 7B Demo 的配置文件`configs/7B_sft.py`为例,介绍启动一个模型训练所需要进行的数据、模型和并行等相关的配置。
|
||||
|
||||
#### 数据配置
|
||||
数据相关的关键参数配置及释义如下所示:
|
||||
```python
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
SEQ_LEN = 2048
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN, # 数据样本长度,默认值为 2048
|
||||
micro_num=1, # micro_num 是指在一次模型参数更新中会处理的 micro_batch 的数目,默认值为 1
|
||||
micro_bsz=1, # packed_length = micro_bsz * SEQ_LEN,为一次处理的 micro_batch 的数据大小,默认值为 1
|
||||
total_steps=50000, # 总的所需执行的 step 的数目,默认值为 50000
|
||||
min_length=50, # 若数据集文件中,数据行数少于50,将会被废弃
|
||||
train_folder=TRAIN_FOLDER, # 数据集文件路径,默认值为 None;若 train_folder 为空,则以自动生成的随机数据集进行训练测试
|
||||
pack_sample_into_one=False, # 数据整理的逻辑,决定是按照 seq_len 维度或者是 sequence 的真实长度来进行attention计算
|
||||
)
|
||||
```
|
||||
|
||||
<div align="left">
|
||||
<img src="./imgs/pack_into_one.png" width="550"/>
|
||||
</div>
|
||||
|
||||
|
||||
目前支持传入数据集文件路径`train_folder`,且要求文件格式如下:
|
||||
```bash
|
||||
- folder
|
||||
- code
|
||||
train_000.bin
|
||||
train_000.bin.meta
|
||||
```
|
||||
数据集的详细内容可参考``数据准备``模块相关的介绍。
|
||||
|
||||
#### 模型配置
|
||||
|
||||
如果在启动训练时要加载模型 `checkpoint`,可进行如下相关配置:
|
||||
```python
|
||||
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
|
||||
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
|
||||
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
|
||||
ckpt = dict(
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # 存储模型和优化器 checkpoint 的路径
|
||||
checkpoint_every=float("inf"), # 每多少个 step 存储一次 checkpoint,默认值为 inf
|
||||
load_model_only_folder=MODEL_ONLY_FOLDER, # 加载模型初始权重的路径,只加载模型权重,不加载优化器权重,训练将从第一个 step 开始
|
||||
load_ckpt_folder=LOAD_CKPT_FOLDER, # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练
|
||||
load_optimizer=True, # 断点续训时,是否需要加载优化器权重,默认值为 True
|
||||
)
|
||||
```
|
||||
注意:
|
||||
- `load_model_only_folder`与`load_ckpt_folder`不能同时设置
|
||||
- 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上
|
||||
|
||||
模型相关关键参数配置如下所示:
|
||||
```python
|
||||
model_type = "INTERNLM" # 模型类型,默认值为 "INTERNLM",对应模型结构初始化接口函数
|
||||
NUM_ATTENTION_HEAD = 32
|
||||
VOCAB_SIZE = 103168
|
||||
HIDDEN_SIZE = 4096
|
||||
NUM_LAYER = 32
|
||||
MLP_RATIO = 8 / 3
|
||||
model = dict(
|
||||
checkpoint=False,
|
||||
num_attention_heads=NUM_ATTENTION_HEAD,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.bfloat16",
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
)
|
||||
```
|
||||
注意:用户可自定义模型类型名和模型结构,并配置相对应的模型参数。通过`utils/registry.py`下的`MODEL_INITIALIZER`对象进行模型初始化函数接口注册,在训练主函数`train.py`中初始化模型时,可通过`model_type`配置获取指定的模型初始化接口函数。
|
||||
|
||||
*如果基于 InternLM 7B继续训练,可以参考 [ModelZoo](https://github.com/InternLM/InternLM/tree/main#model-zoo) 中 OpenXLab 链接下载权重*
|
||||
|
||||
#### 并行配置
|
||||
|
||||
训练并行配置样例如下:
|
||||
```python
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
pipeline=1,
|
||||
tensor=1,
|
||||
)
|
||||
```
|
||||
- zero1:zero 并行策略,分如下三种情况,默认值为 -1
|
||||
- 当`size <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配
|
||||
- 当`size == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数
|
||||
- 当`size > 1`且`size <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集
|
||||
- pipeline:流水线并行大小,目前只支持 1,默认值为 1
|
||||
- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1
|
||||
|
||||
注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小`
|
||||
|
||||
### 启动训练
|
||||
|
||||
完成了以上数据集准备和相关训练配置后,可启动 Demo 训练。接下来分别以 slurm 和 torch 环境为例,介绍训练启动方式。
|
||||
|
||||
若在 slurm 上启动分布式运行环境,多节点 16 卡的运行命令如下所示:
|
||||
```bash
|
||||
$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
|
||||
```
|
||||
|
||||
若在 torch 上启动分布式运行环境,单节点 8 卡的运行命令如下所示:
|
||||
```bash
|
||||
$ torchrun --nnodes=1 --nproc-per-node=8 train.py --config ./configs/7B_sft.py
|
||||
```
|
||||
|
||||
### 运行结果
|
||||
|
||||
以 slurm 上单机 8 卡的 Demo 训练配置为例,训练结果日志展示如下:
|
||||
```bash
|
||||
2023-07-04 21:40:14,148 INFO train.py:318 in record_current_batch_training_metrics -- step=17,loss=9.810295104980469,tgs (tokens per gpu per second)=4399.93,lr=3.8e-06,loss_scale=65536.0,grad_norm=4.177205427229359,micro_num=4,num_consumed_tokens=2359296,inf_nan_skip_batches=0,num_samples_in_batch=60,largest_length=1300,largest_batch=18,smallest_batch=13,adam_beta2=0.95,fwd_bwd_time=3.57
|
||||
2023-07-04 21:40:17,825 INFO train.py:318 in record_current_batch_training_metrics -- step=18,loss=9.715232849121094,tgs (tokens per gpu per second)=4457.7,lr=4.000000000000001e-06,loss_scale=65536.0,grad_norm=5.018154183978863,micro_num=4,num_consumed_tokens=2490368,inf_nan_skip_batches=0,num_samples_in_batch=68,largest_length=1153,largest_batch=19,smallest_batch=16,adam_beta2=0.95,fwd_bwd_time=3.52
|
||||
2023-07-04 21:40:21,526 INFO train.py:318 in record_current_batch_training_metrics -- step=19,loss=9.76744556427002,tgs (tokens per gpu per second)=4429.13,lr=4.2000000000000004e-06,loss_scale=65536.0,grad_norm=5.245329823265071,micro_num=4,num_consumed_tokens=2621440,inf_nan_skip_batches=0,num_samples_in_batch=70,largest_length=706,largest_batch=18,smallest_batch=17,adam_beta2=0.95,fwd_bwd_time=3.54
|
||||
2023-07-04 21:40:25,227 INFO train.py:318 in record_current_batch_training_metrics -- step=20,loss=9.628969192504883,tgs (tokens per gpu per second)=4427.46,lr=4.4e-06,loss_scale=65536.0,grad_norm=5.503176552110271,micro_num=4,num_consumed_tokens=2752512,inf_nan_skip_batches=0,num_samples_in_batch=69,largest_length=915,largest_batch=20,smallest_batch=15,adam_beta2=0.95,fwd_bwd_time=3.55
|
||||
2023-07-04 21:40:28,899 INFO train.py:318 in record_current_batch_training_metrics -- step=21,loss=9.690847396850586,tgs (tokens per gpu per second)=4464.18,lr=4.6e-06,loss_scale=65536.0,grad_norm=5.5336643273197526,micro_num=4,num_consumed_tokens=2883584,inf_nan_skip_batches=0,num_samples_in_batch=66,largest_length=870,largest_batch=17,smallest_batch=16,adam_beta2=0.95,fwd_bwd_time=3.52
|
||||
2023-07-04 21:40:32,629 INFO train.py:318 in record_current_batch_training_metrics -- step=22,loss=9.61986255645752,tgs (tokens per gpu per second)=4393.28,lr=4.800000000000001e-06,loss_scale=65536.0,grad_norm=9.01168869536059,micro_num=4,num_consumed_tokens=3014656,inf_nan_skip_batches=0,num_samples_in_batch=65,largest_length=1151,largest_batch=20,smallest_batch=14,adam_beta2=0.95,fwd_bwd_time=3.57
|
||||
```
|
|
@ -0,0 +1,9 @@
|
|||
from .initialize.initialize_trainer import initialize_trainer
|
||||
from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch
|
||||
|
||||
__all__ = [
|
||||
"get_default_parser",
|
||||
"initialize_trainer",
|
||||
"launch_from_slurm",
|
||||
"launch_from_torch",
|
||||
]
|
|
@ -0,0 +1,848 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
__all__ = ["SequenceGenerator"]
|
||||
|
||||
|
||||
class InferenceParams:
|
||||
"""
|
||||
Intermediate cache objects for inference
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_sequence_len,
|
||||
max_batch_size,
|
||||
sequence_len_offset=0,
|
||||
batch_size_offset=0,
|
||||
key_value_memory_dict: dict = None,
|
||||
lengths_per_sample=None,
|
||||
attention_mask=None,
|
||||
) -> None:
|
||||
|
||||
self.max_sequence_len: int = max_sequence_len
|
||||
self.max_batch_size: int = max_batch_size
|
||||
self.sequence_len_offset: int = sequence_len_offset
|
||||
self.batch_size_offset: int = batch_size_offset
|
||||
if key_value_memory_dict is None:
|
||||
key_value_memory_dict = {}
|
||||
self.key_value_memory_dict: dict = key_value_memory_dict
|
||||
self.fused_ft_kernel: bool = False
|
||||
self.lengths_per_sample = lengths_per_sample
|
||||
self.attention_mask = attention_mask
|
||||
|
||||
def reorder_state(self, indices):
|
||||
if self.lengths_per_sample is not None:
|
||||
self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0)
|
||||
for key, value in list(self.key_value_memory_dict.items()):
|
||||
value = value.index_select(index=indices, dim=0)
|
||||
self.key_value_memory_dict[key] = value
|
||||
|
||||
|
||||
def _get_model_device(model):
|
||||
"""
|
||||
obtain the device of an nn.Module.model
|
||||
|
||||
Args:
|
||||
model: nn.Module
|
||||
|
||||
Return: torch.device. if None, the parameters of this model is None.
|
||||
"""
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
parameters = list(model.parameters())
|
||||
if len(parameters) == 0:
|
||||
return None
|
||||
else:
|
||||
return parameters[0].device
|
||||
|
||||
|
||||
class SequenceGenerator:
|
||||
"""
|
||||
Sequence Generator.
|
||||
"""
|
||||
|
||||
def __init__(self, decoder, eos_token_id, pad_token_id, bos_token_id):
|
||||
self.decoder = decoder
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
tokens: "torch.LongTensor" = None,
|
||||
num_return_sequences=1,
|
||||
max_length: int = 20,
|
||||
num_beams: int = 1,
|
||||
do_sample: bool = True,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = 50,
|
||||
top_p: float = 1.0,
|
||||
repetition_penalty: float = 1,
|
||||
length_penalty: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tokens: the beginning tokens whose shape is [bsz, length]. If shape is None, default ''bos_token'' will be
|
||||
added to conduct generation.
|
||||
num_return_sequences: number of returned sequences.
|
||||
max_length: the max length of generated sequence.
|
||||
num_beams: the size of beam search.
|
||||
do_sample: whether using sample.
|
||||
temperature: it's meaningful when do_sample is True.
|
||||
top_k: sampling from top_k.
|
||||
top_p: sampling from top_p tokens(nucleus sampling).
|
||||
|
||||
Return:
|
||||
the token sequence whose shape is [bsz, num_return_sequences, max_length]. If eos_token_id is not None,
|
||||
the ending of each sequence must be eos_token_id.
|
||||
"""
|
||||
assert num_return_sequences <= num_beams, f"The `{num_return_sequences}` must be less than `{num_beams}`..."
|
||||
if do_sample:
|
||||
return sample_generate(
|
||||
self.decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_return_sequences,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
eos_token_id=self.eos_token_id, # the ending token id
|
||||
pad_token_id=self.pad_token_id,
|
||||
repetition_penalty=repetition_penalty, # the penalty degree for repetition tokens
|
||||
length_penalty=length_penalty, # the penalty for length. if it > 1, then encourages long sequence.
|
||||
# Otherwise, encourages short sequence.
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
else:
|
||||
return greedy_generate(
|
||||
self.decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_return_sequences,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def greedy_generate(
|
||||
decoder,
|
||||
tokens=None,
|
||||
max_length=20,
|
||||
num_beams=1,
|
||||
num_return_sequences=1,
|
||||
eos_token_id=None,
|
||||
pad_token_id=0,
|
||||
repetition_penalty=1,
|
||||
length_penalty=1.0,
|
||||
bos_token_id=1,
|
||||
feat_mask=None,
|
||||
ffn_mask=None,
|
||||
layer_mask=None,
|
||||
):
|
||||
"""
|
||||
Search sequence greedily.
|
||||
|
||||
Args:
|
||||
decoder: the Decoder object.
|
||||
tokens: the shape is [batch size, length]. If decoder is None, generating begins with bos_token_id.
|
||||
max_length: the max length for generated sequence.
|
||||
num_beams: the size of beam to decode.
|
||||
eos_token_id: the ending token id. If None, the decode length is max_length.
|
||||
pad_token_id: the token id of pad.
|
||||
repetition_penalty: the penalty degree for repetition tokens
|
||||
length_penalty: the penalty for length.
|
||||
|
||||
"""
|
||||
if num_beams == 1:
|
||||
token_ids = _no_beam_search_generate(
|
||||
decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
temperature=1,
|
||||
top_k=50,
|
||||
top_p=1,
|
||||
eos_token_id=eos_token_id,
|
||||
do_sample=False,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
feat_mask=feat_mask,
|
||||
ffn_mask=ffn_mask,
|
||||
layer_mask=layer_mask,
|
||||
)
|
||||
else:
|
||||
token_ids = _beam_search_generate(
|
||||
decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_return_sequences,
|
||||
temperature=1,
|
||||
top_k=50,
|
||||
top_p=1,
|
||||
eos_token_id=eos_token_id,
|
||||
do_sample=False,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
feat_mask=feat_mask,
|
||||
ffn_mask=ffn_mask,
|
||||
layer_mask=layer_mask,
|
||||
)
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_generate(
|
||||
decoder,
|
||||
tokens,
|
||||
max_length=20,
|
||||
num_beams=1,
|
||||
num_return_sequences=1,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
eos_token_id=None,
|
||||
pad_token_id=0,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
bos_token_id=1,
|
||||
):
|
||||
"""
|
||||
generate sequence in sampling way.
|
||||
|
||||
Args:
|
||||
decoder: the Decoder object.
|
||||
tokens: the shape is [batch size, length]. If decoder is None, generating begins with bos_token_id.
|
||||
max_length: the max length for generated sequence.
|
||||
num_beams: the size of beam to decode.
|
||||
num_return_sequences: number of returned sequence.
|
||||
temperature: annealing magnitude during sampling.
|
||||
top_k: sampling from top_k. (Default: 50)
|
||||
top_p: sampling from top_p tokens(nucleus sampling). (Default: 1.0)
|
||||
eos_token_id: the ending token id. If None, the decode length is max_length.
|
||||
pad_token_id: the token id of pad.
|
||||
repetition_penalty: the penalty degree for repetition tokens
|
||||
length_penalty: the penalty for length.
|
||||
|
||||
"""
|
||||
if num_beams == 1:
|
||||
token_ids = _no_beam_search_generate(
|
||||
decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
eos_token_id=eos_token_id,
|
||||
do_sample=True,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
)
|
||||
else:
|
||||
token_ids = _beam_search_generate(
|
||||
decoder,
|
||||
tokens=tokens,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_return_sequences,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
eos_token_id=eos_token_id,
|
||||
do_sample=True,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
)
|
||||
return token_ids
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _no_beam_search_generate(
|
||||
decoder,
|
||||
tokens,
|
||||
inference_params=None,
|
||||
max_length=20,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
eos_token_id=None,
|
||||
do_sample=True,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
feat_mask=None,
|
||||
ffn_mask=None,
|
||||
layer_mask=None,
|
||||
):
|
||||
# delete num_return_sequences=1 for lint check;
|
||||
batch_size = tokens.size(0)
|
||||
if eos_token_id is None:
|
||||
_eos_token_id = -1
|
||||
else:
|
||||
_eos_token_id = eos_token_id
|
||||
|
||||
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
|
||||
if has_bos:
|
||||
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
|
||||
bos_sum = bos_pos.cumsum(dim=-1)
|
||||
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
else:
|
||||
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
|
||||
if inference_params is None:
|
||||
inference_params = InferenceParams(
|
||||
max_sequence_len=max_length,
|
||||
max_batch_size=tokens.size(0),
|
||||
sequence_len_offset=0,
|
||||
batch_size_offset=0,
|
||||
key_value_memory_dict=None,
|
||||
lengths_per_sample=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if layer_mask is None:
|
||||
if feat_mask is None and ffn_mask is None:
|
||||
scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": tokens,
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
}
|
||||
)
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": tokens,
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
"layer_mask": layer_mask,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(scores, (list, tuple)):
|
||||
scores = scores[0]
|
||||
scores = scores[:, -1].float()
|
||||
inference_params.sequence_len_offset += tokens.size(1)
|
||||
if _eos_token_id != -1:
|
||||
scores[:, _eos_token_id] = -1e12
|
||||
next_tokens = scores.argmax(dim=-1, keepdim=True)
|
||||
token_ids = torch.cat([tokens, next_tokens], dim=1)
|
||||
cur_len = token_ids.size(1)
|
||||
dones = token_ids.new_zeros(batch_size).eq(1)
|
||||
# tokens = tokens[:, -1:]
|
||||
|
||||
real_max_length = max_length
|
||||
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
|
||||
|
||||
while cur_len < real_max_length:
|
||||
# batch_size x vocab_size
|
||||
if has_bos:
|
||||
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
|
||||
bos_sum = bos_pos.cumsum(dim=-1)
|
||||
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
else:
|
||||
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
|
||||
inference_params.attention_mask = attention_mask
|
||||
if layer_mask is None:
|
||||
if feat_mask is None and ffn_mask is None:
|
||||
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": token_ids[:, -1:],
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
}
|
||||
)
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": token_ids[:, -1:],
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
"layer_mask": layer_mask,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(scores, (list, tuple)):
|
||||
scores = scores[0]
|
||||
scores = scores[:, -1].float()
|
||||
inference_params.sequence_len_offset += 1
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
token_scores = scores.gather(dim=1, index=token_ids)
|
||||
lt_zero_mask = token_scores.lt(0).float()
|
||||
ge_zero_mask = lt_zero_mask.eq(0).float()
|
||||
token_scores = (
|
||||
lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
|
||||
)
|
||||
scores.scatter_(dim=1, index=token_ids, src=token_scores)
|
||||
|
||||
if eos_token_id is not None and length_penalty != 1.0:
|
||||
# batch_size x vocab_size
|
||||
token_scores = scores / cur_len**length_penalty
|
||||
eos_mask = scores.new_ones(scores.size(1))
|
||||
eos_mask[eos_token_id] = 0
|
||||
eos_mask = eos_mask.unsqueeze(0).eq(1)
|
||||
|
||||
scores = scores.masked_scatter(eos_mask, token_scores)
|
||||
|
||||
if do_sample:
|
||||
if temperature > 0 and temperature != 1:
|
||||
scores = scores / temperature
|
||||
|
||||
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
|
||||
# add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
|
||||
probs = F.softmax(scores, dim=-1) + 1e-12
|
||||
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
|
||||
else:
|
||||
next_tokens = torch.argmax(scores, dim=-1) # batch_size
|
||||
|
||||
if _eos_token_id != -1:
|
||||
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), _eos_token_id)
|
||||
next_tokens = next_tokens.masked_fill(dones, pad_token_id)
|
||||
tokens = next_tokens.unsqueeze(1)
|
||||
|
||||
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
|
||||
|
||||
end_mask = next_tokens.eq(_eos_token_id)
|
||||
dones = dones.__or__(end_mask)
|
||||
cur_len += 1
|
||||
|
||||
if dones.min() == 1:
|
||||
break
|
||||
|
||||
# if eos_token_id is not None:
|
||||
# # setting the eos at the maximum length position
|
||||
# tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)
|
||||
# if cur_len == max_length:
|
||||
# # If eos is not reached by the maximum length, forcibly replace the last word with eos
|
||||
# token_ids[:, -1].masked_fill_(~dones, eos_token_id)
|
||||
# TODO Here we are simply adding an extra dimension for interface compatibility, but in the future it will need to
|
||||
# be able to return multiple real results
|
||||
return token_ids[:, None]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _beam_search_generate(
|
||||
decoder,
|
||||
tokens,
|
||||
inference_params=None,
|
||||
max_length=20,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
eos_token_id=None,
|
||||
do_sample=True,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
feat_mask=None,
|
||||
ffn_mask=None,
|
||||
layer_mask=None,
|
||||
) -> torch.LongTensor:
|
||||
|
||||
device = _get_model_device(decoder)
|
||||
batch_size = tokens.size(0)
|
||||
|
||||
if eos_token_id is None:
|
||||
_eos_token_id = -1
|
||||
else:
|
||||
_eos_token_id = eos_token_id
|
||||
|
||||
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
|
||||
|
||||
if has_bos:
|
||||
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
|
||||
bos_sum = bos_pos.cumsum(dim=-1)
|
||||
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
else:
|
||||
bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
|
||||
|
||||
if inference_params is None:
|
||||
inference_params = InferenceParams(
|
||||
max_sequence_len=max_length,
|
||||
max_batch_size=tokens.size(0),
|
||||
sequence_len_offset=0,
|
||||
batch_size_offset=0,
|
||||
key_value_memory_dict=None,
|
||||
lengths_per_sample=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if layer_mask is None:
|
||||
if feat_mask is None and ffn_mask is None:
|
||||
scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": tokens,
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
}
|
||||
)
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": tokens,
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
"layer_mask": layer_mask,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(scores, (list, tuple)):
|
||||
scores = scores[0]
|
||||
scores = scores[:, -1].float()
|
||||
inference_params.sequence_len_offset += tokens.size(1)
|
||||
if _eos_token_id != -1:
|
||||
scores[:, _eos_token_id] = -1e12
|
||||
vocab_size = scores.size(1)
|
||||
assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size."
|
||||
|
||||
if do_sample:
|
||||
probs = F.softmax(scores, dim=-1) + 1e-12
|
||||
# (batch_size, num_beams)
|
||||
next_tokens = torch.multinomial(probs, num_samples=num_beams)
|
||||
logits = probs.log()
|
||||
# (batch_size, num_beams)
|
||||
next_scores = logits.gather(dim=1, index=next_tokens)
|
||||
else:
|
||||
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
|
||||
# obtain (batch_size, num_beams), (batch_size, num_beams)
|
||||
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
|
||||
|
||||
indices = torch.arange(batch_size, dtype=torch.long).to(device)
|
||||
indices = indices.repeat_interleave(num_beams)
|
||||
inference_params.reorder_state(indices)
|
||||
|
||||
# batch_size * num_beams x length
|
||||
tokens = tokens.index_select(dim=0, index=indices)
|
||||
# genrated token (batch_size', cur_len)
|
||||
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
|
||||
dones = [False] * batch_size
|
||||
|
||||
beam_scores = next_scores.view(-1) # batch_size * num_beams
|
||||
|
||||
cur_len = token_ids.size(1)
|
||||
|
||||
real_max_length = max_length
|
||||
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
|
||||
hypos = [
|
||||
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
|
||||
]
|
||||
# 0, num_beams, 2*num_beams, ...
|
||||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
|
||||
|
||||
while cur_len < real_max_length:
|
||||
if has_bos:
|
||||
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
|
||||
bos_sum = bos_pos.cumsum(dim=-1)
|
||||
bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
else:
|
||||
bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
|
||||
to_atten_x = bos_pos[:, :, None]
|
||||
to_atten_y = bos_pos[:, None, :]
|
||||
# attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
|
||||
attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
|
||||
|
||||
inference_params.attention_mask = attention_mask
|
||||
# (bsz x num_beams, vocab_size)
|
||||
|
||||
if layer_mask is None:
|
||||
if feat_mask is None and ffn_mask is None:
|
||||
scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": token_ids[:, -1:],
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
}
|
||||
)
|
||||
else:
|
||||
scores = decoder(
|
||||
**{
|
||||
"input_ids": token_ids[:, -1:],
|
||||
"inference_params": inference_params,
|
||||
"feat_mask": feat_mask,
|
||||
"ffn_mask": ffn_mask,
|
||||
"layer_mask": layer_mask,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(scores, (list, tuple)):
|
||||
scores = scores[0]
|
||||
scores = scores[:, -1].float()
|
||||
inference_params.sequence_len_offset += 1
|
||||
if repetition_penalty != 1.0:
|
||||
token_scores = scores.gather(dim=1, index=token_ids)
|
||||
lt_zero_mask = token_scores.lt(0).float()
|
||||
ge_zero_mask = lt_zero_mask.eq(0).float()
|
||||
token_scores = (
|
||||
lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
|
||||
)
|
||||
scores.scatter_(dim=1, index=token_ids, src=token_scores)
|
||||
|
||||
if _eos_token_id != -1:
|
||||
max_len_eos_mask = max_lengths.eq(cur_len + 1)
|
||||
eos_scores = scores[:, _eos_token_id]
|
||||
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores)
|
||||
|
||||
if do_sample:
|
||||
if temperature > 0 and temperature != 1:
|
||||
scores = scores / temperature
|
||||
|
||||
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
|
||||
# add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
|
||||
probs = F.softmax(scores, dim=-1) + 1e-12
|
||||
|
||||
# batch_size' x (num_beams+1)
|
||||
_tokens = torch.multinomial(probs, num_samples=num_beams + 1)
|
||||
|
||||
logits = probs.log()
|
||||
# batch_size' x (num_beams+1)
|
||||
_scores = logits.gather(dim=1, index=_tokens)
|
||||
# batch_size' x (num_beams+1)
|
||||
_scores = _scores + beam_scores[:, None]
|
||||
_scores = _scores.view(batch_size, num_beams * (num_beams + 1))
|
||||
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
|
||||
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
|
||||
# (batch_size, 2*num_beams)
|
||||
next_tokens = _tokens.gather(dim=1, index=ids)
|
||||
# (batch_size, 2*num_beams)
|
||||
from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()
|
||||
else:
|
||||
# (batch_size * num_beams, vocab_size)
|
||||
scores = F.log_softmax(scores, dim=-1)
|
||||
# (batch_size * num_beams, vocab_size)
|
||||
_scores = scores + beam_scores[:, None]
|
||||
# (batch_size, num_beams*vocab_size)
|
||||
_scores = _scores.view(batch_size, -1)
|
||||
# (bsz, 2*num_beams)
|
||||
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
||||
# (batch_size, 2*num_beams)
|
||||
from_which_beam = torch.floor(ids.float() / vocab_size).long()
|
||||
next_tokens = ids % vocab_size # (batch_size, 2*num_beams)
|
||||
|
||||
# next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
|
||||
# next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
|
||||
# from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
|
||||
|
||||
not_eos_mask = next_tokens.ne(_eos_token_id)
|
||||
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
|
||||
keep_mask = not_eos_mask.__and__(keep_mask)
|
||||
|
||||
_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
|
||||
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
|
||||
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
|
||||
beam_scores = _next_scores.view(-1)
|
||||
|
||||
flag = True
|
||||
if cur_len + 1 == real_max_length:
|
||||
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
|
||||
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
|
||||
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
|
||||
else:
|
||||
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams
|
||||
if effective_eos_mask.sum().gt(0):
|
||||
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
|
||||
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
|
||||
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
|
||||
else:
|
||||
flag = False
|
||||
|
||||
if flag:
|
||||
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
|
||||
for batch_idx, beam_ind, beam_idx in zip(
|
||||
eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()
|
||||
):
|
||||
if not dones[batch_idx]:
|
||||
score = next_scores[batch_idx, beam_ind].item()
|
||||
if _eos_token_id != -1:
|
||||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
|
||||
else:
|
||||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)
|
||||
|
||||
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
|
||||
inference_params.reorder_state(reorder_inds)
|
||||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
dones[batch_idx] = (
|
||||
dones[batch_idx]
|
||||
or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
|
||||
or max_lengths[batch_idx * num_beams] == cur_len + 1
|
||||
)
|
||||
|
||||
cur_len += 1
|
||||
|
||||
if all(dones):
|
||||
break
|
||||
|
||||
# select the best hypotheses
|
||||
tgt_len = token_ids.new_zeros(batch_size, num_return_sequences)
|
||||
best = []
|
||||
|
||||
for i, hypotheses in enumerate(hypos):
|
||||
# best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
||||
sorted_hyp = list(sorted(hypotheses.hyp, key=lambda x: x[0], reverse=True))
|
||||
_best = []
|
||||
for j, hyp in zip(range(num_return_sequences), sorted_hyp):
|
||||
hyp = hyp[1]
|
||||
if _eos_token_id != -1:
|
||||
hyp = torch.cat([hyp, token_ids.new_ones(1) * _eos_token_id])
|
||||
tgt_len[i, j] = len(hyp)
|
||||
_best.append(hyp)
|
||||
best.append(_best)
|
||||
|
||||
# generate target batch
|
||||
decoded = token_ids.new_zeros(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
|
||||
for i, hypo in enumerate(best):
|
||||
for j, _hypo in enumerate(hypo):
|
||||
decoded[i, j, : tgt_len[i, j]] = _hypo
|
||||
|
||||
return decoded
|
||||
|
||||
|
||||
class BeamHypotheses(object):
|
||||
"""
|
||||
BeamHypotheses
|
||||
"""
|
||||
|
||||
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
|
||||
"""Initialize n-best list of hypotheses."""
|
||||
self.max_length = max_length - 1 # ignoring bos_token
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.num_beams = num_beams
|
||||
self.hyp = []
|
||||
self.worst_score = 1e9
|
||||
|
||||
def __len__(self):
|
||||
"""Number of hypotheses in the list."""
|
||||
return len(self.hyp)
|
||||
|
||||
def add(self, hyp, sum_logprobs):
|
||||
"""Add a new hypothesis to the list."""
|
||||
score = sum_logprobs / len(hyp) ** self.length_penalty
|
||||
if len(self) < self.num_beams or score > self.worst_score:
|
||||
self.hyp.append((score, hyp))
|
||||
if len(self) > self.num_beams:
|
||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
|
||||
del self.hyp[sorted_scores[0][1]]
|
||||
self.worst_score = sorted_scores[1][0]
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs):
|
||||
"""If there are enough hypotheses and that none of the hypotheses being
|
||||
generated can become better than the worst one in the heap, then we are
|
||||
done with this sentence."""
|
||||
if len(self) < self.num_beams:
|
||||
return False
|
||||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
return self.worst_score >= best_sum_logprobs / self.max_length**self.length_penalty
|
||||
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||
"""
|
||||
Based on the values of top_k and top_p, set the values that do not meet the criteria to the filter_value.
|
||||
|
||||
Args:
|
||||
logits: logit value, shape is [bsz, vocab_size].
|
||||
top_k: If it is greater than 0, only the probabilities of the top_k vocabulary are kept, and the rest of
|
||||
the positions are set to filter_value.
|
||||
top_p: according to http://arxiv.org/abs/1904.09751.
|
||||
filter_value: filter value
|
||||
min_tokens_to_keep: The probability of words in each sample‘s returned distribution will not be
|
||||
lower than this value.
|
||||
|
||||
"""
|
||||
if top_k > 0:
|
||||
# Safety check
|
||||
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
|
||||
# Remove all tokens with a probability less than the last token of
|
||||
# the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
# (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
if min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
# (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
||||
# Shift the indices to the right to keep also the first token
|
||||
# above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
|
@ -0,0 +1,9 @@
|
|||
from .engine import Engine
|
||||
from .naive_amp import NaiveAMPModel
|
||||
from .trainer import Trainer
|
||||
|
||||
__all__ = [
|
||||
"NaiveAMPModel",
|
||||
"Engine",
|
||||
"Trainer",
|
||||
]
|
|
@ -0,0 +1,47 @@
|
|||
from .parallel_context import (
|
||||
IS_TENSOR_PARALLEL,
|
||||
Config,
|
||||
ParallelContext,
|
||||
global_context,
|
||||
)
|
||||
from .process_group_initializer import (
|
||||
Initializer_Data,
|
||||
Initializer_Model,
|
||||
Initializer_Pipeline,
|
||||
Initializer_Tensor,
|
||||
Initializer_Zero1,
|
||||
ParallelMode,
|
||||
ProcessGroupInitializer,
|
||||
)
|
||||
from .random import (
|
||||
add_seed,
|
||||
get_current_mode,
|
||||
get_seeds,
|
||||
get_states,
|
||||
seed,
|
||||
set_mode,
|
||||
set_seed_states,
|
||||
sync_states,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"IS_TENSOR_PARALLEL",
|
||||
"global_context",
|
||||
"ParallelContext",
|
||||
"ParallelMode",
|
||||
"Initializer_Tensor",
|
||||
"Initializer_Pipeline",
|
||||
"Initializer_Data",
|
||||
"Initializer_Zero1",
|
||||
"ProcessGroupInitializer",
|
||||
"Initializer_Model",
|
||||
"seed",
|
||||
"set_mode",
|
||||
"add_seed",
|
||||
"get_seeds",
|
||||
"get_states",
|
||||
"get_current_mode",
|
||||
"set_seed_states",
|
||||
"sync_states",
|
||||
]
|
|
@ -0,0 +1,548 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
|
||||
|
||||
import inspect
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
from collections import Counter
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
from . import process_group_initializer as pgroup_initializer
|
||||
from .process_group_initializer import ParallelMode
|
||||
from .random import add_seed, get_seeds, set_mode
|
||||
|
||||
IS_TENSOR_PARALLEL = "is_tensor_parallel"
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class Config(dict):
|
||||
"""This is a wrapper class for dict objects so that values of which can be
|
||||
accessed as attributes.
|
||||
|
||||
Args:
|
||||
config (dict): The dict object to be wrapped.
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
if config is not None:
|
||||
for k, v in config.items():
|
||||
self._add_item(k, v)
|
||||
|
||||
def __missing__(self, key):
|
||||
raise KeyError(key)
|
||||
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
value = super().__getitem__(key)
|
||||
return value
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
super().__setitem__(key, value)
|
||||
|
||||
def _add_item(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
self.__setattr__(key, Config(value))
|
||||
else:
|
||||
self.__setattr__(key, value)
|
||||
|
||||
def update(self, config):
|
||||
assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
|
||||
for k, v in config.items():
|
||||
self._add_item(k, v)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def from_file(filename: str):
|
||||
"""Reads a python file and constructs a corresponding :class:`Config` object.
|
||||
|
||||
Args:
|
||||
filename (str): Name of the file to construct the return object.
|
||||
|
||||
Returns:
|
||||
:class:`Config`: A :class:`Config` object constructed with information in the file.
|
||||
|
||||
Raises:
|
||||
AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file
|
||||
"""
|
||||
|
||||
# check config path
|
||||
if isinstance(filename, str):
|
||||
filepath = Path(filename).absolute()
|
||||
elif isinstance(filename, Path):
|
||||
filepath = filename.absolute()
|
||||
|
||||
assert filepath.exists(), f"{filename} is not found, please check your configuration path"
|
||||
|
||||
# check extension
|
||||
extension = filepath.suffix
|
||||
assert extension == ".py", "only .py files are supported"
|
||||
|
||||
# import the config as module
|
||||
remove_path = False
|
||||
if filepath.parent not in sys.path:
|
||||
sys.path.insert(0, (filepath))
|
||||
remove_path = True
|
||||
|
||||
module_name = filepath.stem
|
||||
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
|
||||
module = source_file.load_module() # pylint: disable=W4902,E1120
|
||||
|
||||
# load into config
|
||||
config = Config()
|
||||
|
||||
for k, v in module.__dict__.items():
|
||||
if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
|
||||
continue
|
||||
else:
|
||||
config._add_item(k, v)
|
||||
|
||||
# remove module
|
||||
del sys.modules[module_name]
|
||||
if remove_path:
|
||||
sys.path.pop(0)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class ParallelContext(metaclass=SingletonMeta):
|
||||
"""This class provides interface functions for users to get the parallel context,
|
||||
such as the global rank, the local rank, the world size, etc. of each device.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# distributed settings
|
||||
self._global_ranks = dict()
|
||||
self._local_ranks = dict()
|
||||
self._world_sizes = dict()
|
||||
self._groups = dict()
|
||||
self._cpu_groups = dict()
|
||||
self._ranks_in_group = dict()
|
||||
|
||||
# load config from file
|
||||
self._config = None
|
||||
|
||||
# default parallel args, will be overwritten during process group intialization
|
||||
self.world_size = 1
|
||||
self.data_parallel_size = 1
|
||||
self.pipeline_parallel_size = 1
|
||||
self.tensor_parallel_size = 1
|
||||
self.zero1_parallel_size = -1
|
||||
self.num_processes_on_current_node = -1
|
||||
self.virtual_pipeline_parallel_size = None
|
||||
self.virtual_pipeline_parallel_rank = None
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._config
|
||||
|
||||
def load_config(self, config: Union[dict, str]):
|
||||
"""Loads the configuration from either a dict or a file.
|
||||
|
||||
Args:
|
||||
config (dict or str): Either a dict containing the configuration information or the filename
|
||||
of a file containing the configuration information.
|
||||
|
||||
Raises:
|
||||
TypeError: Raises a TypeError if `config` is neither a dict nor a str.
|
||||
"""
|
||||
if isinstance(config, str):
|
||||
self._config = Config.from_file(config)
|
||||
elif isinstance(config, dict):
|
||||
self._config = Config(config)
|
||||
else:
|
||||
raise TypeError("Invalid type for config, only dictionary or string is supported")
|
||||
|
||||
def detect_num_processes_on_current_node(self):
|
||||
hostname = socket.gethostname()
|
||||
hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
|
||||
dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
|
||||
counter = Counter(hostname_list)
|
||||
self.num_processes_on_current_node = counter[hostname]
|
||||
|
||||
@staticmethod
|
||||
def _check_parallel_mode(parallel_mode: ParallelMode):
|
||||
assert isinstance(
|
||||
parallel_mode, ParallelMode
|
||||
), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}"
|
||||
|
||||
def get_global_rank(self):
|
||||
"""Returns the global rank of the current device.
|
||||
|
||||
Returns:
|
||||
int: The global rank of the current device
|
||||
"""
|
||||
return self._global_ranks[ParallelMode.GLOBAL]
|
||||
|
||||
def get_local_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns the local rank of the current device.
|
||||
|
||||
Args:
|
||||
parallel_mode: The parallel mode for the rank.
|
||||
|
||||
Returns:
|
||||
int: The local rank of the current device for `parallel_mode`.
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._local_ranks.get(parallel_mode, 0)
|
||||
|
||||
def get_next_global_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns the global rank of the next device.
|
||||
|
||||
Args:
|
||||
parallel_mode: The parallel mode for the rank.
|
||||
|
||||
Returns:
|
||||
int: The global rank of the next device for `parallel_mode`.
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
|
||||
# get rank and world size
|
||||
local_rank = self.get_local_rank(parallel_mode)
|
||||
world_size = self.get_world_size(parallel_mode)
|
||||
ranks_in_group = self.get_ranks_in_group(parallel_mode)
|
||||
|
||||
return ranks_in_group[(local_rank + 1) % world_size]
|
||||
|
||||
def get_prev_global_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns the global rank of the previous device.
|
||||
|
||||
Args:
|
||||
parallel_mode: The chosen parallel mode.
|
||||
|
||||
Returns:
|
||||
int: The global rank of the previous device for `parallel_mode`.
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
|
||||
# get rank and world size
|
||||
local_rank = self.get_local_rank(parallel_mode)
|
||||
world_size = self.get_world_size(parallel_mode)
|
||||
ranks_in_group = self.get_ranks_in_group(parallel_mode)
|
||||
|
||||
return ranks_in_group[(local_rank - 1) % world_size]
|
||||
|
||||
def is_using_dp(self):
|
||||
"""Returns a boolean value indicating whether the current device is initilized with
|
||||
ParallelMode.DATA and its world_size is greater than 1.
|
||||
"""
|
||||
return self.is_initialized(ParallelMode.DATA) and self.get_world_size(ParallelMode.DATA) > 1
|
||||
|
||||
def is_using_tp(self):
|
||||
"""Returns a boolean value indicating whether the current device is initilized with
|
||||
ParallelMode.TENSOR and its world_size is greater than 1.
|
||||
"""
|
||||
return self.is_initialized(ParallelMode.TENSOR) and self.get_world_size(ParallelMode.TENSOR) > 1
|
||||
|
||||
def is_using_pp(self):
|
||||
"""Returns a boolean value indicating whether the current device is initilized with
|
||||
ParallelMode.PIPELINE and its world_size is greater than 1.
|
||||
"""
|
||||
return self.is_initialized(ParallelMode.PIPELINE) and self.get_world_size(ParallelMode.PIPELINE) > 1
|
||||
|
||||
def is_using_sequence(self):
|
||||
"""Returns a boolean value indicating whether the current device is initilized with
|
||||
ParallelMode.SEQUENCE and its world_size is greater than 1.
|
||||
"""
|
||||
return False
|
||||
# return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
|
||||
|
||||
def is_first_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns a boolean value indicating whether the current device is the first one
|
||||
among its group for `parallel_mode`.
|
||||
|
||||
Args:
|
||||
parallel_mode: The chosen parallel mode.
|
||||
|
||||
Returns:
|
||||
bool: a boolean value indicating whether the current device is the first one
|
||||
among its group for `parallel_mode`.
|
||||
"""
|
||||
rank = 0
|
||||
if self.is_initialized(parallel_mode):
|
||||
rank = self.get_local_rank(parallel_mode)
|
||||
return rank == 0
|
||||
|
||||
def is_rank_for_log(self):
|
||||
"""Returns a boolean value indicating whether the current device should print log."""
|
||||
is_log_rank = (
|
||||
self.is_first_rank(ParallelMode.DATA)
|
||||
and self.is_first_rank(ParallelMode.TENSOR)
|
||||
and self.is_last_rank(ParallelMode.PIPELINE)
|
||||
)
|
||||
return is_log_rank
|
||||
|
||||
def is_last_rank(self, parallel_mode: ParallelMode):
|
||||
"""Returns a boolean value indicating whether the current device is the last one
|
||||
among its group for `parallel_mode`.
|
||||
|
||||
Args:
|
||||
parallel_mode: The chosen parallel mode.
|
||||
|
||||
Returns:
|
||||
bool: a boolean value indicating whether the current device is the first one
|
||||
among its group for `parallel_mode`.
|
||||
"""
|
||||
rank = 0
|
||||
world_size = 1
|
||||
if self.is_initialized(parallel_mode):
|
||||
rank = self.get_local_rank(parallel_mode)
|
||||
world_size = self.get_world_size(parallel_mode)
|
||||
return rank == world_size - 1
|
||||
|
||||
def is_pipeline_first_stage(self, ignore_virtual=False):
|
||||
if not ignore_virtual:
|
||||
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0:
|
||||
return False
|
||||
return self.is_first_rank(ParallelMode.PIPELINE)
|
||||
|
||||
def is_pipeline_last_stage(self, ignore_virtual=False):
|
||||
if not ignore_virtual:
|
||||
if (
|
||||
self.virtual_pipeline_parallel_size is not None
|
||||
and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1
|
||||
):
|
||||
return False
|
||||
return self.is_last_rank(ParallelMode.PIPELINE)
|
||||
|
||||
def get_world_size(self, parallel_mode: ParallelMode):
|
||||
"""Returns the world size for `parallel_mode`.
|
||||
|
||||
Args:
|
||||
parallel_mode: The chosen parallel mode.
|
||||
|
||||
Returns:
|
||||
int: The world size for `parallel_mode`.
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._world_sizes.get(parallel_mode, 1)
|
||||
|
||||
def get_group(self, parallel_mode: ParallelMode):
|
||||
"""Returns the group of the current device for `parallel_mode`.
|
||||
|
||||
Args:
|
||||
parallel_mode: The chosen parallel mode.
|
||||
|
||||
Returns:
|
||||
torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._groups[parallel_mode]
|
||||
|
||||
def get_ranks_in_group(self, parallel_mode: ParallelMode):
|
||||
"""Returns the rank of the current device for `parallel_mode` in the group.
|
||||
|
||||
Args:
|
||||
parallel_mode: The chosen parallel mode.
|
||||
|
||||
Returns:
|
||||
int: The rank of the current device for `parallel_mode` in the group.
|
||||
"""
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._ranks_in_group[parallel_mode]
|
||||
|
||||
def get_cpu_group(self, parallel_mode: ParallelMode):
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
return self._cpu_groups[parallel_mode]
|
||||
|
||||
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int, use_cpu: bool = False):
|
||||
"""Initializes the global distributed environment
|
||||
|
||||
Args:
|
||||
rank (int): rank for the default process group.
|
||||
world_size (int): world size of the default process group.
|
||||
backend (str): backend for ``torch.distributed``
|
||||
host (str): the master address for distributed training.
|
||||
port (str): the master port for distributed training.
|
||||
use_cpu (bool): whether to set up cpu process group.
|
||||
"""
|
||||
# initialize the default process group
|
||||
init_method = f"tcp://[{host}]:{port}"
|
||||
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
|
||||
|
||||
# None will give the default global process group for pytorch dist operations
|
||||
ranks = list(range(world_size))
|
||||
if use_cpu:
|
||||
cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
|
||||
else:
|
||||
cpu_group = None
|
||||
self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
|
||||
self._global_ranks[ParallelMode.GLOBAL] = rank
|
||||
|
||||
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
|
||||
self._check_parallel_mode(mode)
|
||||
self._local_ranks[mode] = local_rank
|
||||
self._world_sizes[mode] = world_size
|
||||
self._groups[mode] = process_group
|
||||
self._cpu_groups[mode] = cpu_group
|
||||
self._ranks_in_group[mode] = ranks_in_group
|
||||
|
||||
def check_sanity(self):
|
||||
"""Checks sanity of the parallel context.
|
||||
|
||||
Raises:
|
||||
AssertionError: Raises an AssertionError if the world size does not equal to the product
|
||||
of data parallel size, pipeline parallel size and tensor parallel size.
|
||||
"""
|
||||
dps = self.data_parallel_size
|
||||
pps = self.pipeline_parallel_size
|
||||
tps = self.tensor_parallel_size
|
||||
ws = self.world_size
|
||||
assert ws == dps * pps * tps, (
|
||||
f"Expected the world size {ws} to be equal to data"
|
||||
f" parallel size ({dps}) * pipeline parallel size "
|
||||
f"({pps}) * tensor parallel size ({tps})"
|
||||
)
|
||||
assert self.zero1_parallel_size > 0
|
||||
assert self.data_parallel_size % self.zero1_parallel_size == 0
|
||||
|
||||
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
|
||||
if key in config:
|
||||
ele = config[key]
|
||||
if isinstance(ele, int):
|
||||
setattr(self, attr_name, ele)
|
||||
elif isinstance(ele, dict):
|
||||
setattr(self, attr_name, ele["size"])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'{"Parallel configuration does not support this kind of argument, please use int or dict"}'
|
||||
)
|
||||
|
||||
def init_parallel_groups(self):
|
||||
"""Initializes the parallel groups."""
|
||||
|
||||
# get rank and world size
|
||||
rank = self.get_global_rank()
|
||||
world_size = self.get_world_size(ParallelMode.GLOBAL)
|
||||
self.world_size = world_size
|
||||
|
||||
# set parallel size as attributes for global context
|
||||
parallel_config = self.config.get("parallel", None)
|
||||
if parallel_config is not None:
|
||||
self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
|
||||
self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
|
||||
self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
|
||||
|
||||
# the user should not set the data parallel size manually
|
||||
# instead, it should be calculated based on other parallel config
|
||||
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
|
||||
|
||||
if self.zero1_parallel_size <= 0:
|
||||
self.zero1_parallel_size = self.data_parallel_size
|
||||
|
||||
self.check_sanity()
|
||||
|
||||
initializer_args = [
|
||||
rank,
|
||||
world_size,
|
||||
self.data_parallel_size,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
self.zero1_parallel_size,
|
||||
]
|
||||
|
||||
# run initialization of different process groups
|
||||
initializers = []
|
||||
initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
||||
if self.pipeline_parallel_size > 1:
|
||||
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
|
||||
|
||||
for initializer in initializers:
|
||||
parallel_setting = initializer.init_dist_group()
|
||||
if isinstance(parallel_setting, list):
|
||||
for args in parallel_setting:
|
||||
self._register_dist(*args)
|
||||
else:
|
||||
self._register_dist(*parallel_setting)
|
||||
|
||||
def is_initialized(self, parallel_mode: ParallelMode):
|
||||
"""Returns a boolean value indicating whether `parallel_mode` is initialized
|
||||
in the current system.
|
||||
"""
|
||||
return parallel_mode in self._groups
|
||||
|
||||
def destroy(self):
|
||||
"""Destroys the current distributed parallel environment."""
|
||||
for mode, group in self._groups.items():
|
||||
if mode is not ParallelMode.GLOBAL:
|
||||
dist.destroy_process_group(group)
|
||||
# destroy global process group
|
||||
dist.destroy_process_group()
|
||||
self._groups.clear()
|
||||
|
||||
def set_device(self, device_ordinal: int = None):
|
||||
"""Sets distributed processes to be bound to devices.
|
||||
|
||||
Args:
|
||||
device_ordinal (int, optional): the device id to be bound to
|
||||
"""
|
||||
global_rank = self.get_global_rank()
|
||||
if device_ordinal is None:
|
||||
devices_per_node = torch.cuda.device_count()
|
||||
device_ordinal = global_rank % devices_per_node
|
||||
|
||||
torch.cuda.set_device(device_ordinal)
|
||||
logger.info(f"process rank {global_rank} is bound to host:{socket.gethostname()} device: {device_ordinal}")
|
||||
|
||||
def set_seed(self, seed: int, dpseed_with_tpoffset: bool = False):
|
||||
"""Sets seeds for all random libraries.
|
||||
|
||||
Args:
|
||||
seed (int): seed for random states
|
||||
"""
|
||||
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
|
||||
global_rank = self.get_global_rank()
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
# data parallel seed are kept the same in the same pipeline stage
|
||||
dp_seed = seed
|
||||
if dpseed_with_tpoffset:
|
||||
dp_seed = seed + pipeline_offset * 1024
|
||||
add_seed(ParallelMode.DATA, dp_seed)
|
||||
|
||||
# model parallel seeds are different across ranks
|
||||
if self.is_initialized(ParallelMode.TENSOR):
|
||||
tp_rank = self.get_local_rank(ParallelMode.TENSOR)
|
||||
tp_seed = seed + tp_rank + pipeline_offset * 1024
|
||||
add_seed(ParallelMode.TENSOR, tp_seed)
|
||||
|
||||
set_mode(ParallelMode.DATA)
|
||||
|
||||
seeds = get_seeds()
|
||||
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
|
||||
logger.info(
|
||||
f"initialized seed on rank {global_rank}, "
|
||||
f"numpy: {seed}, python random: {seed}, {seed_str},"
|
||||
f"the default parallel seed is {ParallelMode.DATA}."
|
||||
)
|
||||
|
||||
def set_virtual_pipeline_parallel_size(self, size):
|
||||
self.virtual_pipeline_parallel_size = size
|
||||
|
||||
def set_virtual_pipeline_parallel_rank(self, rank):
|
||||
self.virtual_pipeline_parallel_rank = rank
|
||||
|
||||
|
||||
global_context = ParallelContext()
|
|
@ -0,0 +1,334 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# parallel modes
|
||||
class ParallelMode(Enum):
|
||||
"""This is an enumeration class containing all possible parallel modes."""
|
||||
|
||||
GLOBAL = "global"
|
||||
|
||||
# common parallel
|
||||
DATA = "data"
|
||||
|
||||
# model parallel - containing tensor and pipeline parallel groups
|
||||
# this is added to facilitate amp and grad clipping in hybrid parallel
|
||||
MODEL = "model"
|
||||
|
||||
# pipeline parallel
|
||||
PIPELINE = "pipe"
|
||||
|
||||
# containing all ranks in tensor parallel
|
||||
TENSOR = "tensor"
|
||||
|
||||
# zero1 parallel
|
||||
ZERO1 = "zero1"
|
||||
|
||||
|
||||
class ProcessGroupInitializer(ABC):
|
||||
"""An object, knowing the parallelism configuration, that initializes parallel groups.
|
||||
|
||||
Args:
|
||||
rank (int): The rank of current process.
|
||||
world_size (int): Size of whole communication world.
|
||||
data_parallel_size (int): Size of data parallel.
|
||||
pipeline_parallel_size (int): Size of pipeline parallel.
|
||||
tensor_parallel_size (int): Size of tensor parallel.
|
||||
zero1_parallel_size (int): Size of zero1 parallel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
data_parallel_size: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
zero1_parallel_size: int,
|
||||
):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.pipeline_parallel_size = pipeline_parallel_size
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
self.zero1_parallel_size = zero1_parallel_size
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def init_dist_group(self, use_cpu: bool = False):
|
||||
pass
|
||||
|
||||
|
||||
class Initializer_Data(ProcessGroupInitializer):
|
||||
"""A ProcessGroupInitializer for data parallelism.
|
||||
|
||||
Args:
|
||||
rank (int): The rank of current process.
|
||||
world_size (int): Size of whole communication world.
|
||||
data_parallel_size (int): Size of data parallel.
|
||||
pipeline_parallel_size (int): Size of pipeline parallel.
|
||||
tensor_parallel_size (int): Size of tensor parallel.
|
||||
zero1_parallel_size (int): Size of zero1 parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
|
||||
|
||||
assert self.world_size % self.data_parallel_size == 0
|
||||
|
||||
def init_dist_group(self, use_cpu: bool = False):
|
||||
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
Returns:
|
||||
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
|
||||
A Data parallelism's information tuple.
|
||||
"""
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
cpu_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.DATA
|
||||
|
||||
for i in range(self.rank_num_per_dp_group):
|
||||
ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
cpu_group = group_cpu
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_Model(ProcessGroupInitializer):
|
||||
"""A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel
|
||||
groups).
|
||||
|
||||
Args:
|
||||
rank (int): The rank of current process.
|
||||
world_size (int): Size of whole communication world.
|
||||
data_parallel_size (int): Size of data parallel.
|
||||
pipeline_parallel_size (int): Size of pipeline parallel.
|
||||
tensor_parallel_size (int): Size of tensor parallel.
|
||||
zero1_parallel_size (int): Size of zero1 parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.rank_num_per_group = self.tensor_parallel_size * self.pipeline_parallel_size
|
||||
self.num_group = self.world_size // self.rank_num_per_group
|
||||
|
||||
assert self.world_size % self.rank_num_per_group == 0
|
||||
|
||||
def init_dist_group(self, use_cpu: bool = False):
|
||||
"""Initialize model parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
Returns:
|
||||
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
|
||||
A Model parallelism's information tuple.
|
||||
"""
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
cpu_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.MODEL
|
||||
|
||||
for i in range(self.num_group):
|
||||
ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
|
||||
group = dist.new_group(ranks)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
cpu_group = group_cpu
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_Pipeline(ProcessGroupInitializer):
|
||||
"""A ProcessGroupInitializer for pipeline parallelism.
|
||||
|
||||
Args:
|
||||
rank (int): The rank of current process
|
||||
world_size (int): Size of whole communication world
|
||||
data_parallel_size (int): Size of data parallel
|
||||
pipeline_parallel_size (int): Size of pipeline parallel
|
||||
tensor_parallel_size (int): Size of tensor parallel
|
||||
zero1_parallel_size (int): Size of zero1 parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
|
||||
self.pipeline_stage_size = self.rank_num_per_dp_group // self.pipeline_parallel_size
|
||||
|
||||
assert self.world_size % self.data_parallel_size == 0
|
||||
assert self.rank_num_per_dp_group % self.pipeline_parallel_size == 0
|
||||
|
||||
def init_dist_group(self, use_cpu: bool = False):
|
||||
"""Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
Returns:
|
||||
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
|
||||
A Pipeline parallelism's information in list of tuples.
|
||||
"""
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
cpu_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.PIPELINE
|
||||
|
||||
for i in range(self.data_parallel_size):
|
||||
for j in range(self.pipeline_stage_size):
|
||||
ranks = list(
|
||||
range(
|
||||
i * self.rank_num_per_dp_group + j,
|
||||
(i + 1) * self.rank_num_per_dp_group,
|
||||
self.pipeline_stage_size,
|
||||
)
|
||||
)
|
||||
pipe_group_size = len(ranks)
|
||||
pipe_group = dist.new_group(ranks)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = pipe_group_size
|
||||
process_group = pipe_group
|
||||
cpu_group = group_cpu
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_Tensor(ProcessGroupInitializer):
|
||||
"""A ProcessGroupInitializer for tensor parallelism.
|
||||
|
||||
Args:
|
||||
rank (int): The rank of current process.
|
||||
world_size (int): Size of whole communication world.
|
||||
data_parallel_size (int): Size of data parallel.
|
||||
pipeline_parallel_size (int): Size of pipeline parallel.
|
||||
tensor_parallel_size (int): Size of tensor parallel.
|
||||
zero1_parallel_size (int): Size of zero1 parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size
|
||||
|
||||
assert self.world_size % self.tensor_parallel_size == 0
|
||||
|
||||
def init_dist_group(self, use_cpu: bool = False):
|
||||
"""Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
Returns:
|
||||
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
|
||||
A Tensor parallelism's information tuple.
|
||||
"""
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
cpu_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.TENSOR
|
||||
|
||||
for i in range(self.num_tensor_parallel_group):
|
||||
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
cpu_group = group_cpu
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
||||
|
||||
|
||||
class Initializer_Zero1(ProcessGroupInitializer):
|
||||
"""A ProcessGroupInitializer for zero-1 parallelism.
|
||||
|
||||
Args:
|
||||
rank (int): The rank of current process.
|
||||
world_size (int): Size of whole communication world.
|
||||
data_parallel_size (int): Size of data parallel.
|
||||
pipeline_parallel_size (int): Size of pipeline parallel.
|
||||
tensor_parallel_size (int): Size of tensor parallel.
|
||||
zero1_parallel_size (int): Size of zero-1 parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
|
||||
self.num_zero1_parallel_group = self.data_parallel_size // self.zero1_parallel_size
|
||||
|
||||
assert self.world_size % self.data_parallel_size == 0
|
||||
assert self.world_size % self.zero1_parallel_size == 0
|
||||
|
||||
def init_dist_group(self, use_cpu: bool = False):
|
||||
"""Initialize zero1 parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
Returns:
|
||||
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
|
||||
A zero1 parallelism's information tuple.
|
||||
"""
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
cpu_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.ZERO1
|
||||
|
||||
for i in range(self.rank_num_per_dp_group):
|
||||
for j in range(self.num_zero1_parallel_group):
|
||||
ranks = [
|
||||
i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
|
||||
for k in range(self.zero1_parallel_size)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
cpu_group = group_cpu
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
|
@ -0,0 +1,131 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch import Tensor
|
||||
|
||||
from .process_group_initializer import ParallelMode
|
||||
|
||||
|
||||
class SeedManager:
|
||||
"""This class is a manager of all random seeds involved in the system."""
|
||||
|
||||
def __init__(self):
|
||||
self._current_mode = None
|
||||
self._seeds = {}
|
||||
self._seed_states = {}
|
||||
|
||||
@property
|
||||
def current_mode(self):
|
||||
return self._current_mode
|
||||
|
||||
@property
|
||||
def seeds(self):
|
||||
return self._seeds
|
||||
|
||||
@property
|
||||
def seed_states(self):
|
||||
return self._seed_states
|
||||
|
||||
def set_state(self, parallel_mode: ParallelMode, state: Tensor):
|
||||
"""Sets the state of the seed manager for `parallel_mode`."""
|
||||
assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager"
|
||||
self._seed_states[parallel_mode] = state
|
||||
|
||||
def set_mode(self, parallel_mode: ParallelMode):
|
||||
"""Sets the current mode of the seed manager."""
|
||||
if self.current_mode:
|
||||
# save state for current mode
|
||||
self._seed_states[self._current_mode] = torch.cuda.get_rng_state()
|
||||
|
||||
# set new state for new mode
|
||||
self._current_mode = parallel_mode
|
||||
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
|
||||
|
||||
def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`."""
|
||||
assert isinstance(parallel_mode, ParallelMode), "Invalid ParallelMode"
|
||||
if not overwrite:
|
||||
assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists"
|
||||
elif parallel_mode in self._seed_states:
|
||||
print(f"Warning: {parallel_mode} seed overwritten.", flush=True)
|
||||
|
||||
current_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.manual_seed(seed)
|
||||
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
|
||||
self._seeds[parallel_mode] = seed
|
||||
torch.cuda.set_rng_state(current_state)
|
||||
|
||||
def reset(self):
|
||||
self._current_mode = None
|
||||
self._seeds = {}
|
||||
self._seed_states = {}
|
||||
|
||||
|
||||
_SEED_MANAGER = SeedManager()
|
||||
|
||||
|
||||
def get_seeds():
|
||||
"""Returns the seeds of the seed manager.
|
||||
Returns:
|
||||
dict: The seeds of the seed manager.
|
||||
"""
|
||||
return _SEED_MANAGER.seeds
|
||||
|
||||
|
||||
def get_states(copy=False):
|
||||
"""Returns the seed states of the seed manager.
|
||||
Returns:
|
||||
dict: The seed states of the seed manager.
|
||||
"""
|
||||
states = _SEED_MANAGER.seed_states
|
||||
if copy:
|
||||
new_states = dict()
|
||||
for parallel_mode, state in states.items():
|
||||
new_states[parallel_mode] = state.clone()
|
||||
return new_states
|
||||
else:
|
||||
return _SEED_MANAGER.seed_states
|
||||
|
||||
|
||||
def get_current_mode():
|
||||
"""Returns the current mode of the seed manager.
|
||||
Returns:
|
||||
:class:`torch.ByteTensor`: The current mode of the seed manager.
|
||||
"""
|
||||
return _SEED_MANAGER.current_mode
|
||||
|
||||
|
||||
def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
|
||||
"""Adds a seed to the seed manager for `parallel_mode`."""
|
||||
_SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
|
||||
|
||||
|
||||
def set_mode(parallel_mode: ParallelMode):
|
||||
"""Sets the current mode of the seed manager."""
|
||||
_SEED_MANAGER.set_mode(parallel_mode)
|
||||
|
||||
|
||||
def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
|
||||
"""Sets the state of the seed manager for `parallel_mode`."""
|
||||
_SEED_MANAGER.set_state(parallel_mode, state)
|
||||
|
||||
|
||||
def sync_states():
|
||||
current_mode = get_current_mode()
|
||||
current_states = torch.cuda.get_rng_state()
|
||||
set_seed_states(current_mode, current_states)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seed(parallel_mode: ParallelMode):
|
||||
"""A context for seed switch"""
|
||||
current_mode = _SEED_MANAGER.current_mode
|
||||
try:
|
||||
yield _SEED_MANAGER.set_mode(parallel_mode)
|
||||
finally:
|
||||
_SEED_MANAGER.set_mode(current_mode)
|
|
@ -0,0 +1,190 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from internlm.core.gradient_handler import BaseGradientHandler
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
|
||||
from internlm.utils.common import get_batch_size, move_to_device
|
||||
|
||||
|
||||
class Engine:
|
||||
"""
|
||||
The Engine class is responsible for managing the training and evaluation process of a neural network model.
|
||||
It handles the forward and backward passes, parameter updates, gradient handling, and mode switching between
|
||||
training and evaluation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The neural network model to be trained or evaluated.
|
||||
optimizer (BaseOptimizer): The optimizer used for updating the parameters of the model.
|
||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler for the optimizer.
|
||||
Default is None.
|
||||
beta2_scheduler (internlm.solver.beta2_scheduler.Beta2Scheduler, optional): The beta2 scheduler for the
|
||||
optimizer. Default is None.
|
||||
criterion (torch.nn.modules.loss._Loss, optional): The loss function used for calculating the loss during
|
||||
training. Default is None.
|
||||
gradient_handlers (List[BaseGradientHandler], optional): A list of gradient handlers used in the backward pass.
|
||||
Default is None.
|
||||
clip_grad_norm (float, optional): The norm value for gradient clipping. Default is 0.0.
|
||||
|
||||
Examples:
|
||||
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
|
||||
>>> model = ...
|
||||
>>> criterion = ...
|
||||
>>> optimizer = ...
|
||||
>>> train_dataloader = ...
|
||||
>>> engine, _, _, _ = internlm.initialize_engine(model, optimizer, criterion)
|
||||
>>> engine.train()
|
||||
>>> for inputs, labels in train_dataloader
|
||||
>>> # set gradients to zero
|
||||
>>> engine.zero_grad()
|
||||
>>> # run forward pass
|
||||
>>> outputs = engine(inputs)
|
||||
>>> # compute loss value and run backward pass
|
||||
>>> loss = engine.criterion(outputs, labels)
|
||||
>>> engine.backward(loss)
|
||||
>>> # update parameters
|
||||
>>> engine.step()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Module,
|
||||
optimizer: BaseOptimizer,
|
||||
lr_scheduler: Optional[_LRScheduler] = None,
|
||||
beta2_scheduler: Optional[Beta2Scheduler] = None,
|
||||
criterion: Optional[_Loss] = None,
|
||||
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
|
||||
clip_grad_norm: float = 0.0,
|
||||
):
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
self._lr_scheduler = lr_scheduler
|
||||
self._beta2_scheduler = beta2_scheduler
|
||||
self._criterion = criterion
|
||||
self._clip_grad_norm = clip_grad_norm
|
||||
|
||||
# state
|
||||
self.training = True # default
|
||||
|
||||
# build gradient handler
|
||||
self._gradient_handlers = gradient_handlers if gradient_handlers else []
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Returns the model attached to the engine."""
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
"""Returns the optimizer attached to the engine."""
|
||||
return self._optimizer
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
"""Returns the criterion (loss function) attached to the engine."""
|
||||
return self._criterion
|
||||
|
||||
def _all_reduce_gradients(self):
|
||||
"""Handles all-reduce operations of gradients across different parallel groups."""
|
||||
for handler in self._gradient_handlers:
|
||||
handler.handle_gradient()
|
||||
|
||||
def zero_grad(self):
|
||||
"""Sets the gradient of all parameters in the model to zero."""
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
|
||||
and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
|
||||
if they exist.
|
||||
|
||||
Returns:
|
||||
success (bool): Whether the parameter update was successful.
|
||||
grad_norm (float): The norm of the gradient after clipping.
|
||||
"""
|
||||
self._all_reduce_gradients()
|
||||
self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
|
||||
|
||||
success, grad_norm = self.optimizer.step()
|
||||
|
||||
if success and self._lr_scheduler is not None:
|
||||
self._lr_scheduler.step()
|
||||
|
||||
if success and self._beta2_scheduler is not None:
|
||||
self._beta2_scheduler.step()
|
||||
|
||||
return success, grad_norm
|
||||
|
||||
def train(self):
|
||||
"""Sets the model to training mode."""
|
||||
self.training = True
|
||||
self._model.train()
|
||||
|
||||
def eval(self):
|
||||
"""Sets the model to evaluation mode."""
|
||||
self.training = False
|
||||
self._model.eval()
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
"""
|
||||
Starts the backward propagation given the loss value computed by a loss function.
|
||||
|
||||
Args:
|
||||
loss (torch.Tensor): The loss value computed by a loss function.
|
||||
"""
|
||||
return self.optimizer.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
"""
|
||||
Starts the backward propagation given the gradient of the output tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The output tensor.
|
||||
grad (torch.Tensor): The gradient passed back to the output tensor.
|
||||
"""
|
||||
return self.optimizer.backward_by_grad(tensor, grad)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
Runs the forward step for the model.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output of the model.
|
||||
"""
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
def load_batch(self, data_iter, to_gpu=True):
|
||||
"""
|
||||
Loads a batch from the data iterator. It returns the data and labels which are
|
||||
already in the same GPU as where the model is.
|
||||
|
||||
Args:
|
||||
data_iter (Iterable): The data iterator from which to get a batch of data, obtained by calling
|
||||
iter(dataloader).
|
||||
to_gpu (bool, optional): Whether the data should be moved to the GPU. Default is True.
|
||||
|
||||
Returns:
|
||||
Tuple (torch.Tensor, torch.Tensor): A tuple of (data, label).
|
||||
"""
|
||||
if data_iter is None:
|
||||
raise RuntimeError("Dataloader is not defined.")
|
||||
try:
|
||||
batch_data = next(data_iter)
|
||||
except TypeError:
|
||||
batch_data = data_iter
|
||||
|
||||
if to_gpu:
|
||||
batch_data = move_to_device(batch_data)
|
||||
batch_size = get_batch_size(batch_data)
|
||||
|
||||
return batch_data, batch_size
|
|
@ -0,0 +1,76 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
class BaseGradientHandler(ABC):
|
||||
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups
|
||||
before optimization.
|
||||
|
||||
Args:
|
||||
model (Module): Model where the gradients accumulate.
|
||||
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, model, optimizer):
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
|
||||
@abstractmethod
|
||||
def handle_gradient(self):
|
||||
"""A method to accumulate gradients across different parallel groups. Users should
|
||||
write their own functions or just use the functions in pre-defined subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PipelineSharedModuleGradientHandler(BaseGradientHandler):
|
||||
"""A helper class to handle all-reduce operations in sub parallel groups.
|
||||
A all-reduce collective communication will be operated in
|
||||
:func:`handle_gradient` among all sub pipeline parallel groups.
|
||||
For better performance, it bucketizes the gradients of all parameters that are
|
||||
the same type to improve the efficiency of communication.
|
||||
|
||||
Args:
|
||||
model (Module): Model where the gradients accumulate.
|
||||
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||
"""
|
||||
|
||||
def handle_gradient(self):
|
||||
"""A method running a all-reduce operation in sub pipeline parallel groups."""
|
||||
if gpc.pipeline_parallel_size > 1:
|
||||
# bucketize and all-reduce
|
||||
buckets = defaultdict(lambda: defaultdict(list))
|
||||
# Pack the buckets.
|
||||
for param in self._model.parameters():
|
||||
group = getattr(param, "pipeline_shared_module_pg", None)
|
||||
if (
|
||||
param.requires_grad
|
||||
and group is not None
|
||||
and (
|
||||
(hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null())
|
||||
or param.grad is not None
|
||||
)
|
||||
):
|
||||
tp = param.data.type()
|
||||
buckets[group][tp].append(param)
|
||||
|
||||
# For each bucket, all-reduce and copy all-reduced grads.
|
||||
for group, group_buckets in buckets.items():
|
||||
for tp, bucket in group_buckets.items():
|
||||
grads = [
|
||||
param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data
|
||||
for param in bucket
|
||||
]
|
||||
coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
|
|
@ -0,0 +1,130 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor, nn
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
|
||||
|
||||
class NaiveAMPModel(nn.Module):
|
||||
"""
|
||||
This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
|
||||
It also provides options to cast the output back to fp32 and to synchronize buffers.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to be wrapped and cast into fp16.
|
||||
output_to_fp32 (bool, optional): If True, the output of this module is cast into fp32. Defaults to True.
|
||||
parallel_mode (:class:`internlm.core.context.ParallelMode`): The parallel group mode used in this module.
|
||||
Defaults to ``ParallelMode.DATA``.
|
||||
sync_buffer (bool, optional): If True, the buffers are synchronized. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
output_to_fp32: bool = True,
|
||||
parallel_mode: ParallelMode = ParallelMode.DATA,
|
||||
sync_buffer: bool = True,
|
||||
dtype=torch.float16,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model.to(dtype)
|
||||
self._output_to_fp32 = output_to_fp32
|
||||
self._sync_buf = sync_buffer
|
||||
self.dtype = dtype
|
||||
|
||||
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
|
||||
self._process_group = gpc.get_group(parallel_mode)
|
||||
self._world_size = gpc.get_world_size(parallel_mode)
|
||||
else:
|
||||
self._process_group = None
|
||||
self._world_size = 1
|
||||
self._sync_buf = False
|
||||
self._first_eval_run = False
|
||||
|
||||
@property
|
||||
def sync_buffer(self):
|
||||
"""Returns the current state of the buffer synchronization."""
|
||||
return self._sync_buf
|
||||
|
||||
@sync_buffer.setter
|
||||
def sync_buffer(self, state: bool):
|
||||
"""Sets the state of the buffer synchronization."""
|
||||
self._sync_buf = state
|
||||
|
||||
def _convert_to_fp16(self, input_: Any):
|
||||
"""Converts the input to fp16 if it is a Tensor of dtype float32."""
|
||||
if isinstance(input_, Tensor) and input_.dtype == torch.float32:
|
||||
input_ = input_.to(self.dtype)
|
||||
return input_
|
||||
|
||||
def _convert_to_fp32(self, input_: Any):
|
||||
"""Converts the input to fp32 if it is a Tensor of dtype float16."""
|
||||
if isinstance(input_, Tensor) and input_.dtype == torch.float16:
|
||||
input_ = input_.float()
|
||||
return input_
|
||||
|
||||
def _reduce_module_buffer(self):
|
||||
"""
|
||||
All-reduces the buffers (e.g., running stats of batch normalization) across
|
||||
data parallel ranks so that all the ranks will produce consistent results
|
||||
when given the same input.
|
||||
"""
|
||||
buf_list = []
|
||||
|
||||
# find valid buffers
|
||||
for buf in self.model.buffers():
|
||||
if buf is not None:
|
||||
buf_list.append(buf)
|
||||
|
||||
# reduce buffers across data parallel ranks
|
||||
if buf_list:
|
||||
coalesced_buf = _flatten_dense_tensors(buf_list)
|
||||
coalesced_buf.div_(self._world_size)
|
||||
dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
|
||||
unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
|
||||
for old, new in zip(buf_list, unflattened_buf_list):
|
||||
old.copy_(new)
|
||||
|
||||
def eval(self):
|
||||
"""Sets the model to evaluation mode. Buffers are only synchronized in the first eval iteration."""
|
||||
self.model.eval()
|
||||
self._first_eval_run = True
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Performs a forward pass on the model. Buffers are synchronized before the forward pass.
|
||||
The inputs are converted to fp16 and the outputs are optionally converted back to fp32.
|
||||
"""
|
||||
if (self.training or self._first_eval_run) and self._sync_buf:
|
||||
with torch.no_grad():
|
||||
self._reduce_module_buffer()
|
||||
|
||||
if self._first_eval_run:
|
||||
self._first_eval_run = False
|
||||
|
||||
if args:
|
||||
args = [self._convert_to_fp16(arg) for arg in args]
|
||||
if kwargs:
|
||||
for k, v in kwargs.items():
|
||||
kwargs[k] = self._convert_to_fp16(v)
|
||||
|
||||
out = self.model(*args, **kwargs)
|
||||
|
||||
if self._output_to_fp32:
|
||||
if isinstance(out, Tensor):
|
||||
out = self._convert_to_fp32(out)
|
||||
elif isinstance(out, (tuple, list)):
|
||||
out = [self._convert_to_fp32(val) for val in out]
|
||||
elif isinstance(out, dict):
|
||||
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||
return out
|
|
@ -0,0 +1,279 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
|
||||
|
||||
class BaseScheduler(ABC):
|
||||
"""A basic helper class to control the process of training or evaluation.
|
||||
It mainly composes of forward_backward_step for gradient backward and
|
||||
optimizer_step for parameters update.
|
||||
For the convenience to enable FP16, we aggregate all codes that contain the
|
||||
control of FP16 in class schedule.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
|
||||
them into data and label.
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None):
|
||||
self.data_process_func = data_process_func
|
||||
|
||||
@abstractmethod
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""To perform actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function over a batch of dataset for training or evaluation.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
forward_only (bool): If True, the process won't include backward.
|
||||
return_loss (bool, optional): If False, the loss won't be returned.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(engine: Engine, inputs: Any):
|
||||
"""Calls the engine with the given inputs.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
return engine(inputs)
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
return engine(*inputs)
|
||||
elif isinstance(inputs, dict):
|
||||
return engine(**inputs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
|
||||
"""Calls the engine's criterion with the given outputs and labels.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
|
||||
labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
assert isinstance(
|
||||
outputs, (torch.Tensor, list, tuple, dict)
|
||||
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = (labels,)
|
||||
|
||||
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
|
||||
return engine.criterion(*outputs, *labels)
|
||||
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
|
||||
return engine.criterion(*outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, dict):
|
||||
return engine.criterion(**outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
|
||||
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected model outputs and labels to be of type torch.Tensor ' \
|
||||
'(which is auto-converted to tuple), list, tuple, or dict, ' \
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
||||
|
||||
|
||||
class NonPipelineScheduler(BaseScheduler):
|
||||
"""A helper schedule class for no pipeline parallelism running environment.
|
||||
During one process, it loads a batch of dataset and feeds it to the model.
|
||||
After getting the output and calculating the loss, it will use :meth:`step`
|
||||
to update the parameters if it is in training mode.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data
|
||||
and returns a tuple in the form of (data, label), and it will be executed in load_batch.
|
||||
gradient_accumulation_steps(int, optional): the steps of gradient accumulation, 1 for disable
|
||||
gradient accumulation.
|
||||
|
||||
Example:
|
||||
# this shows an example of customized data_process_func
|
||||
def data_process_func(dataloader_output):
|
||||
item1, item2, item3 = dataloader_output
|
||||
data = (item1, item2)
|
||||
label = item3
|
||||
return data, label
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
|
||||
# check that non-pipeline schedule data process func only takes in one parameter
|
||||
# which is the batch data
|
||||
if data_process_func:
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(sig.parameters) == 1, (
|
||||
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
|
||||
"which is a tuple of tensors for the current batch, "
|
||||
"i.e. data_process_func(dataloader_output)."
|
||||
)
|
||||
|
||||
self._grad_accum_size = gradient_accumulation_size
|
||||
self._grad_accum_batch_size = 1 # static batch size for flash attetion.
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
super().__init__(data_process_func)
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""Performs actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _load_accum_batch(self, data: Any, label: Any):
|
||||
"""Loads a batch of data and label for gradient accumulation.
|
||||
|
||||
Args:
|
||||
data (Any): The data to be loaded.
|
||||
label (Any): The label to be loaded.
|
||||
"""
|
||||
_data = {
|
||||
k: v[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
|
||||
for k, v in data.items()
|
||||
}
|
||||
_label = label[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
|
||||
|
||||
self._grad_accum_offset += self._grad_accum_batch_size
|
||||
|
||||
return _data, _label
|
||||
|
||||
def _train_one_batch(
|
||||
self,
|
||||
data: Any,
|
||||
label: Any,
|
||||
engine: Engine,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
scale_loss: int = 1,
|
||||
):
|
||||
"""Trains one batch of data.
|
||||
|
||||
Args:
|
||||
data (Any): The data to be trained.
|
||||
label (Any): The label for the data.
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
|
||||
be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
scale_loss (int, optional): The scale factor for the loss.
|
||||
"""
|
||||
|
||||
# forward
|
||||
with conditional_context(torch.no_grad(), enable=forward_only):
|
||||
output = self._call_engine(engine, data)
|
||||
|
||||
if return_loss:
|
||||
loss = self._call_engine_criterion(engine, output, label)
|
||||
loss /= scale_loss
|
||||
|
||||
# backward
|
||||
if not forward_only:
|
||||
engine.backward(loss)
|
||||
|
||||
if not return_loss:
|
||||
loss = None
|
||||
|
||||
return output, loss
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
|
||||
forward_only (bool, optional):
|
||||
If True, the model is run for the forward pass, else back propagation will be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
return_output_label (bool, optional): Output and label will be returned if True.
|
||||
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
|
||||
batch_data, batch_size = engine.load_batch(data_iter)
|
||||
|
||||
assert (
|
||||
batch_size == self._grad_accum_size
|
||||
), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
|
||||
|
||||
if self.data_process_func:
|
||||
data, label = self.data_process_func(batch_data)
|
||||
else:
|
||||
# if not batch data process func is given,
|
||||
# then we regard the batch data as a simple tuple of (data, label)
|
||||
data, label = batch_data
|
||||
|
||||
loss = 0 if return_loss else None
|
||||
outputs = []
|
||||
labels = []
|
||||
|
||||
# reset accumulation microbatch offset
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
for _current_accum_step in range(self._grad_accum_size):
|
||||
if _current_accum_step == self._grad_accum_size - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
else:
|
||||
engine.optimizer.skip_grad_reduce = True
|
||||
|
||||
_data, _label = self._load_accum_batch(data, label)
|
||||
|
||||
_output, _loss = self._train_one_batch(
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
loss += _loss
|
||||
|
||||
outputs.append(_output)
|
||||
labels.append(_label)
|
||||
|
||||
if not return_output_label:
|
||||
outputs, labels = None, None
|
||||
|
||||
return outputs, labels, loss
|
|
@ -0,0 +1,155 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
import json
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
|
||||
|
||||
|
||||
class TrainState:
|
||||
"""
|
||||
The TrainState class is used to record the current state of training.
|
||||
|
||||
Args:
|
||||
train_dl (DataLoader): The DataLoader object used for training.
|
||||
"""
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
# The number of batches produced by the data iterator
|
||||
self.batch_count: int = 0
|
||||
# Used to store the number of samples consumed in the current epoch
|
||||
self.num_consumed_samples_in_epoch: int = 0
|
||||
# Total number of tokens consumed
|
||||
self.num_consumed_tokens: int = 0
|
||||
# Number of batches skipped due to inf or nan values
|
||||
self.inf_nan_skip_batches: int = 0
|
||||
# Records the number of updates, skipped batches and inf batches are not counted
|
||||
self.step_count: int = 0
|
||||
|
||||
# Total step count
|
||||
self.total_steps: int = config.data.total_steps
|
||||
|
||||
def init_batch_sampler(self, train_dl):
|
||||
# Copy of the batch sampler from the DataLoader
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
# Iterator for the batch sampler
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Returns a string representation of the training state in JSON format."""
|
||||
info = {
|
||||
"batch_count": self.batch_count,
|
||||
"num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
|
||||
"num_consumed_tokens": self.num_consumed_tokens,
|
||||
"inf_nan_skip_batches": self.inf_nan_skip_batches,
|
||||
"step_count": self.step_count,
|
||||
}
|
||||
|
||||
return json.dumps(info, indent=4, sort_keys=True)
|
||||
|
||||
def load_state_dict(self, other_stuffs, train_dl):
|
||||
"""
|
||||
Resumes training from a checkpoint.
|
||||
|
||||
Args:
|
||||
other_stuffs (dict): Other information needed to resume training.
|
||||
train_dl (DataLoader): The DataLoader object used for training.
|
||||
"""
|
||||
|
||||
self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward
|
||||
self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
|
||||
self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
|
||||
self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
|
||||
# compatible with previous checkpoints without this parameter
|
||||
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
|
||||
|
||||
# track the actual updates of sampler when using weighted sampling
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
"batch_count": self.batch_count,
|
||||
"num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
|
||||
"num_consumed_tokens": self.num_consumed_tokens,
|
||||
"inf_nan_skip_batches": self.inf_nan_skip_batches,
|
||||
"step_count": self.step_count,
|
||||
}
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""This is a class tending for easy deployments of users' training and evaluation instead of
|
||||
writing their own scripts.
|
||||
|
||||
Args:
|
||||
engine (:class:`Engine`): Engine responsible for the process function.
|
||||
schedule (:class:`BaseScheduler`, optional): Runtime schedule. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
schedule: Optional[BaseScheduler] = None,
|
||||
):
|
||||
"""Initializes the Trainer class.
|
||||
|
||||
Args:
|
||||
engine (Engine): The engine responsible for the process function.
|
||||
schedule (Optional[BaseScheduler], optional): The runtime schedule. Defaults to None.
|
||||
"""
|
||||
self._engine = engine
|
||||
|
||||
# build schedule
|
||||
if schedule is None:
|
||||
self._schedule = NonPipelineScheduler()
|
||||
else:
|
||||
assert isinstance(
|
||||
schedule, BaseScheduler
|
||||
), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
|
||||
self._schedule = schedule
|
||||
|
||||
if self.uses_pipeline:
|
||||
self._schedule.pre_processing(self)
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
return self._engine
|
||||
|
||||
@property
|
||||
def schedule(self):
|
||||
return self._schedule
|
||||
|
||||
@property
|
||||
def uses_pipeline(self):
|
||||
"""Returns whether the pipeline parallel is used or not."""
|
||||
return False
|
||||
|
||||
def train(self):
|
||||
self._engine.train()
|
||||
|
||||
def eval(self):
|
||||
self._engine.eval()
|
||||
|
||||
def zero_grad(self):
|
||||
self._engine.zero_grad()
|
||||
|
||||
def step(self):
|
||||
return self._engine.step()
|
||||
|
||||
def execute_schedule(self, data_iter: Iterable, **kwargs):
|
||||
"""Runs the forward, loss computation, and backward for the model.
|
||||
Returns a tuple of (output, label, loss).
|
||||
|
||||
Args:
|
||||
data_iter (Iterable): The data iterator.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
|
||||
"""
|
||||
output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
|
||||
return output, label, loss
|
|
@ -0,0 +1,13 @@
|
|||
from .batch_sampler import get_dpsampler_dataloader
|
||||
from .collaters import jsonl_ds_collate_fn, packed_collate_fn
|
||||
from .dummy_dataset import RandomDataset
|
||||
from .packed_dataset import PackedDataset, PackedDatasetWithoutCuSeqlen
|
||||
|
||||
__all__ = [
|
||||
"jsonl_ds_collate_fn",
|
||||
"packed_collate_fn",
|
||||
"RandomDataset",
|
||||
"PackedDataset",
|
||||
"PackedDatasetWithoutCuSeqlen",
|
||||
"get_dpsampler_dataloader",
|
||||
]
|
|
@ -0,0 +1,359 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Iterator, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, Sampler
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class DataParallelSampler(Sampler):
|
||||
"""A data sampler for distributed data parallelism.
|
||||
|
||||
Args:
|
||||
dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling.
|
||||
shuffle (bool, optional): Whether to shuffle data, defaults to False.
|
||||
seed (int, optional): The random seed used for sampling, defaults to 0.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
) -> None:
|
||||
self.dataset = dataset
|
||||
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
|
||||
self.rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||
self.epoch = 0
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
# type: ignore[arg-type]
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil(
|
||||
# `type:ignore` is required because Dataset cannot provide a default __len__
|
||||
# see NOTE in pytorch/torch/utils/data/sampler.py
|
||||
(len(self.dataset) - self.num_replicas)
|
||||
/ self.num_replicas # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
def __iter__(self) -> Iterator[T_co]:
|
||||
if self.shuffle:
|
||||
# deterministically shuffle based on epoch and seed
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
# type: ignore[arg-type]
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
|
||||
# update for next epoch so that there is no need to call
|
||||
# set_epoch manually
|
||||
self.epoch += 1
|
||||
else:
|
||||
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
||||
|
||||
if not self.drop_last:
|
||||
# add extra samples to make it evenly divisible
|
||||
padding_size = self.total_size - len(indices)
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
||||
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
||||
sampler will yield the same ordering.
|
||||
|
||||
Args:
|
||||
epoch (int): Epoch number.
|
||||
"""
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
def get_dpsampler_dataloader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
add_sampler=True,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
**kwargs,
|
||||
):
|
||||
r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
|
||||
|
||||
Note:
|
||||
When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
|
||||
on the 1st stage and label on the last stage.
|
||||
|
||||
Args:
|
||||
dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
|
||||
if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
|
||||
sampler = DataParallelSampler(dataset, shuffle=shuffle, drop_last=drop_last)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker():
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
if sampler is None:
|
||||
return DataLoader(
|
||||
dataset,
|
||||
worker_init_fn=seed_worker,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
else:
|
||||
return DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
|
||||
|
||||
class StaticBatchSampler:
|
||||
"""
|
||||
A static batch sampler that generates batches with a fixed micro-batch size.
|
||||
|
||||
Args:
|
||||
num_samples (int): The total number of samples in the dataset.
|
||||
batch_size (int): The batch size for the current rank. Defaults to 192.
|
||||
rampup_batch_size (str): A string with three space-separated integers representing the
|
||||
starting batch size, the increment, and the number of steps between
|
||||
each increment. For example, "192 24 8" means that the batch size
|
||||
starts at 192 and increases by 24 every 8 steps. Defaults to
|
||||
"6 2 8", which corresponds to a batch size of 2 for the first 6 steps.
|
||||
micro_bsz (int): The micro-batch size. Defaults to 2.
|
||||
seed (int): The random seed for shuffling the indices. Defaults to 0.
|
||||
drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
|
||||
data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.
|
||||
data_world_size (int): The number of processes in the data parallel group. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
datasets,
|
||||
batch_size=192,
|
||||
rampup_batch_size="6 2 8",
|
||||
micro_bsz=2,
|
||||
seed=0,
|
||||
drop_last=True,
|
||||
data_rank=0,
|
||||
data_world_size=1,
|
||||
):
|
||||
assert drop_last is True, "Currently only support drop last"
|
||||
if rampup_batch_size:
|
||||
# In the process increase to batch_size
|
||||
start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
|
||||
else:
|
||||
start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
|
||||
self.raw_rampup_batch_size = rampup_batch_size
|
||||
self.start_bsz = start_bsz
|
||||
self.bsz_incre = bsz_incre
|
||||
self.incre_every = incre_every
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
assert (
|
||||
batch_size - self.start_bsz
|
||||
) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
|
||||
assert (
|
||||
self.start_bsz // micro_bsz >= 4
|
||||
), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
|
||||
`{micro_bsz}`, so that the pipeline can run correctly"
|
||||
|
||||
assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
|
||||
assert (
|
||||
self.start_bsz % micro_bsz == 0
|
||||
), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
|
||||
assert (
|
||||
self.bsz_incre % micro_bsz == 0
|
||||
), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.epoch = 0
|
||||
self.seed = seed
|
||||
self.rng = np.random.RandomState(seed)
|
||||
self.batch_count = 0
|
||||
self.micro_bsz = micro_bsz
|
||||
self.data_rank = data_rank
|
||||
self.data_world_size = data_world_size
|
||||
self.num_consumed_samples_in_epoch = 0
|
||||
self.datasets = datasets
|
||||
self.num_samples = sum([len(ds) for ds in datasets])
|
||||
|
||||
self.get_indices() # get data
|
||||
|
||||
def get_indices(self, old_indices=None):
|
||||
if old_indices is not None:
|
||||
assert (
|
||||
len(old_indices) <= self.num_samples
|
||||
), f"The checkpoint has {len(old_indices)} samples, \
|
||||
while the new restart use less samples ({self.num_samples})"
|
||||
|
||||
else:
|
||||
old_indices = np.array([])
|
||||
|
||||
# indices includes len(old_indices) but not self.num_samples
|
||||
indices = np.arange(len(old_indices), self.num_samples)
|
||||
self.rng_state = self.rng.get_state()
|
||||
self.rng.shuffle(indices)
|
||||
# Need to consider drop_last
|
||||
ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre
|
||||
if self.batch_count < ramp_steps * self.incre_every:
|
||||
rampup_samples = 0
|
||||
for i in range(ramp_steps):
|
||||
rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every
|
||||
assert (
|
||||
rampup_samples * self.data_world_size <= self.num_samples
|
||||
), f"Too much rampup samples: \
|
||||
{rampup_samples*self.data_world_size} Vs. self.num_samples: {self.num_samples}"
|
||||
|
||||
num_samples = (self.num_samples - rampup_samples * self.data_world_size) // (
|
||||
self.batch_size * self.data_world_size
|
||||
)
|
||||
num_samples = num_samples * self.batch_size * self.data_world_size + rampup_samples * self.data_world_size
|
||||
else:
|
||||
num_samples = self.num_samples // (self.batch_size * self.data_world_size)
|
||||
num_samples = num_samples * self.batch_size * self.data_world_size
|
||||
indices = np.concatenate([old_indices, indices]).astype(int) # It needs to be spliced with the previous
|
||||
indices = indices[:num_samples]
|
||||
self.indices = indices
|
||||
assert len(self.indices) >= self.batch_size, "The number of samples should be larger than batch_size"
|
||||
self.num_consumed_samples_in_epoch = 0
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
self.rng = np.random.RandomState(self.seed + self.epoch)
|
||||
|
||||
def __len__(self):
|
||||
ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre
|
||||
if self.batch_count < ramp_steps * self.incre_every:
|
||||
rampup_samples = 0
|
||||
for i in range(ramp_steps):
|
||||
rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every
|
||||
assert (
|
||||
rampup_samples * self.data_world_size <= self.num_samples
|
||||
), f"Too much rampup samples: {rampup_samples*self.data_world_size} \
|
||||
Vs. self.num_samples: {self.num_samples}"
|
||||
|
||||
num_batches = (self.num_samples - rampup_samples * self.data_world_size) // self.batch_size
|
||||
num_batches = num_batches // self.data_world_size + self.incre_every * ramp_steps
|
||||
else:
|
||||
num_batches = self.num_samples // self.batch_size // self.data_world_size
|
||||
|
||||
return num_batches
|
||||
|
||||
def __iter__(self):
|
||||
indices = self.indices[self.data_rank :: self.data_world_size]
|
||||
while self.num_consumed_samples_in_epoch < len(indices):
|
||||
batch_rampup_idx = self.batch_count // self.incre_every
|
||||
cur_batch_size = batch_rampup_idx * self.bsz_incre + self.start_bsz
|
||||
cur_batch_size = min(cur_batch_size, self.batch_size)
|
||||
batch = indices[self.num_consumed_samples_in_epoch : self.num_consumed_samples_in_epoch + cur_batch_size]
|
||||
yield batch
|
||||
self.num_consumed_samples_in_epoch += len(batch) # Consider multiple processes.
|
||||
self.batch_count += 1
|
||||
self.get_indices() # get a new round
|
||||
|
||||
def state_dict(self):
|
||||
states = {
|
||||
"batch_size": self.batch_size,
|
||||
"raw_rampup_batch_size": self.raw_rampup_batch_size,
|
||||
"rng_state": self.rng_state,
|
||||
"epoch": self.epoch,
|
||||
"seed": self.seed,
|
||||
"data_world_size": self.data_world_size,
|
||||
"num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
|
||||
"batch_count": self.batch_count, # The batch_count here is due to the existence of multiple processes,
|
||||
# the batch may be oversent, and it needs to be overwritten by the external batch_count
|
||||
"indices": self.indices, # The sequence used to breakpoint retraining is the same as before
|
||||
}
|
||||
|
||||
return states
|
||||
|
||||
def load_state_dict(self, states):
|
||||
for name in ("data_world_size", "raw_rampup_batch_size", "seed"): # 'batch_size'
|
||||
assert states[name] == getattr(self, name), (name, states[name], getattr(self, name)) # should not change
|
||||
self.rng.set_state(states["rng_state"])
|
||||
self.get_indices(old_indices=None) # Regenerate indices based on random state
|
||||
self.epoch = states["epoch"]
|
||||
self.batch_count = states["batch_count"]
|
||||
self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"]
|
||||
|
||||
def copy(self):
|
||||
copy_sampler = StaticBatchSampler(
|
||||
self.datasets,
|
||||
self.batch_size,
|
||||
self.raw_rampup_batch_size,
|
||||
self.micro_bsz,
|
||||
self.seed,
|
||||
drop_last=True,
|
||||
data_rank=self.data_rank,
|
||||
data_world_size=self.data_world_size,
|
||||
)
|
||||
|
||||
copy_sampler.load_state_dict(self.state_dict())
|
||||
return copy_sampler
|
|
@ -0,0 +1,88 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def packed_collate_fn(batch, packed_length):
|
||||
|
||||
"""
|
||||
Collate function for packed input sequences.
|
||||
|
||||
Args:
|
||||
batch (List[Dict]): List of dictionaries representing each sample in batch.
|
||||
Each dictionary contains "tokens", "labels", "type_ids", "cu_seqlens", and "indexes" keys.
|
||||
packed_length (int): The length of packed sequence.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids",
|
||||
"cu_seqlens", "indexes", and "type_ids" keys, and the tensor of padded "labels".
|
||||
|
||||
Raises:
|
||||
AssertionError: If the length of a sample is not equal to packed_length.
|
||||
AssertionError: If the shape of the padded "input_ids" tensor does not have the correct shape.
|
||||
"""
|
||||
|
||||
xs, ys, cu_seqlens, indexes, ts = [], [], [], [], []
|
||||
for b in batch:
|
||||
assert (
|
||||
len(b["tokens"]) == packed_length
|
||||
), f"length of a sample should be equal to packed_length, but got {len(b['tokens'])} and {packed_length})"
|
||||
assert (
|
||||
len(b["labels"]) == packed_length
|
||||
), f"length of a sample should be equal to packed_length, but got {len(b['labels'])} and {packed_length})"
|
||||
assert (
|
||||
len(b["type_ids"]) == packed_length
|
||||
), f"length of a sample should be equal to packed_length, but got {len(b['type_ids'])} and {packed_length})"
|
||||
|
||||
tokens = [abs(w) for w in b["tokens"]]
|
||||
labels = [w if w > 0 else -100 for w in b["labels"]]
|
||||
|
||||
xs.append(torch.LongTensor(tokens))
|
||||
# The labels have been shifted here, so they are aligned with the output corresponding to the token
|
||||
ys.append(torch.LongTensor(labels))
|
||||
ts.append(torch.LongTensor(b["type_ids"]))
|
||||
cu_seqlens.append(torch.IntTensor(b["cu_seqlens"]))
|
||||
indexes.append(torch.LongTensor(b["indexes"]))
|
||||
|
||||
xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)
|
||||
ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100)
|
||||
ts = torch.nn.utils.rnn.pad_sequence(ts, batch_first=True, padding_value=0)
|
||||
indexes = torch.stack(indexes, dim=0)
|
||||
if len(set(map(len, cu_seqlens))) == 1: # if has uniform length, then stack to save device transfer time
|
||||
cu_seqlens = torch.stack(cu_seqlens, dim=0)
|
||||
|
||||
assert xs.shape[1] == packed_length, (xs.shape[1], packed_length)
|
||||
|
||||
return {"input_ids": xs, "cu_seqlens": cu_seqlens, "indexes": indexes, "type_ids": ts}, ys
|
||||
|
||||
|
||||
def jsonl_ds_collate_fn(batch, max_length_per_sample):
|
||||
"""
|
||||
Collate function for json dataset.
|
||||
|
||||
Args:
|
||||
batch (List[Dict]): List of dictionaries representing each sample in batch.
|
||||
Each dictionary contains "tokens".
|
||||
max_length_per_sample (int): The length of output sequence.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids",
|
||||
and the tensor of padded "labels".
|
||||
|
||||
"""
|
||||
xs, ys = [], []
|
||||
for x in batch:
|
||||
x["tokens"] = x["tokens"][:max_length_per_sample]
|
||||
tokens = [abs(w) for w in x["tokens"]]
|
||||
labels = [w if w > 0 else -100 for w in x["tokens"]]
|
||||
labels = labels[1:] + [-100]
|
||||
xs.append(torch.as_tensor(tokens))
|
||||
ys.append(torch.as_tensor(labels)) # y has been shifted
|
||||
xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)
|
||||
ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100)
|
||||
|
||||
xs = torch.cat([xs, xs.new_zeros(len(xs), max_length_per_sample - len(xs[0]))], dim=-1)
|
||||
ys = torch.cat([ys, ys.new_full((len(ys), max_length_per_sample - len(ys[0])), fill_value=-100)], dim=-1)
|
||||
|
||||
return {"input_ids": xs}, ys
|
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
"""
|
||||
RandomDataset for generating random dataset.
|
||||
|
||||
Args:
|
||||
num_samples (int): The number of samples to generate.
|
||||
max_len (int): The maximum length of each sample.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_samples=10000, max_len=1024) -> None:
|
||||
super().__init__()
|
||||
rng = np.random.RandomState(1999)
|
||||
max_num = rng.randint(1, 30, size=(num_samples,))
|
||||
rep_num = rng.randint(10, 200, size=(num_samples,))
|
||||
data = []
|
||||
lengths = []
|
||||
for n, r in zip(max_num, rep_num):
|
||||
d = list(range(n)) * r
|
||||
d = [n, r] + d
|
||||
d = d[:max_len]
|
||||
data.append(d)
|
||||
lengths.append(len(d))
|
||||
self.data = data
|
||||
self.max_len = max_len
|
||||
self.lengths = np.array(lengths, dtype=int)
|
||||
|
||||
def __getitem__(self, index):
|
||||
d = self.data[index]
|
||||
input_ids = np.array(d, dtype=int)
|
||||
return {"tokens": list(input_ids), "type_id": 0}
|
||||
|
||||
def get_dataset_name(self):
|
||||
return "dummy_path/dummy_lang/dummy_ds/train.bin"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
|
@ -0,0 +1,376 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import itertools as it
|
||||
import operator
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.data.single_dataset import JsonlDataset
|
||||
from internlm.data.utils import get_dataset_type_id
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
DEFAULT_SEED = 1024
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class PackedDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
The class PackedDataset takes in a dataset and aggregates samples of different
|
||||
lengths together based on the packed_length.
|
||||
|
||||
Args:
|
||||
dataset: The original dataset to pack.
|
||||
max_length_per_sample: The maximum length of each original sample. Default is 2048.
|
||||
packed_length: The length of each packed sample. Default is 4096.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
max_length_per_sample: int = 2048,
|
||||
packed_length: int = 4096,
|
||||
):
|
||||
assert hasattr(dataset, "lengths")
|
||||
assert len(getattr(dataset, "lengths")) == len(
|
||||
dataset
|
||||
), "The dataset must have lengths attribute and have the same length as the dataset"
|
||||
self.dataset = dataset
|
||||
self.max_length_per_sample = max_length_per_sample
|
||||
self.lengths = getattr(self.dataset, "lengths")
|
||||
self.packed_length = packed_length
|
||||
# Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting
|
||||
|
||||
self.seed = DEFAULT_SEED
|
||||
self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed)
|
||||
self.num_tokens = sum(self.lengths)
|
||||
|
||||
def get_dataset_name(self):
|
||||
return self.dataset.get_dataset_name()
|
||||
|
||||
def accu_sample_len(self, seed=None):
|
||||
"""accumulative length of samples"""
|
||||
if seed is not None:
|
||||
rng = np.random.RandomState(seed)
|
||||
else:
|
||||
rng = np.random.RandomState(self.seed - 1)
|
||||
|
||||
sample_indices = np.arange(len(self.lengths))
|
||||
rng.shuffle(sample_indices)
|
||||
len_samples_shuffled = list(map(self.lengths.__getitem__, sample_indices))
|
||||
acm_len_samples = list(it.accumulate(len_samples_shuffled, operator.add))
|
||||
return sample_indices, len_samples_shuffled, acm_len_samples
|
||||
|
||||
def __len__(self):
|
||||
# Line 405 of document_to_sequence.py in metaseq is directly spliced,
|
||||
# without additional consideration of sos or eos
|
||||
n_packs = self.num_tokens // self.packed_length
|
||||
return n_packs
|
||||
|
||||
def cal_map(self, carriage_idx: int = 0):
|
||||
assert carriage_idx >= 0
|
||||
length_train = (carriage_idx + 1) * self.packed_length
|
||||
post_pos = np.searchsorted(self.acm_len_samples, length_train, side="left")
|
||||
return post_pos
|
||||
|
||||
def mapping(self, pack_idx: int = 0):
|
||||
# pack_idx is zero-based
|
||||
pre_pos, pre_token_id = 0, 0
|
||||
if pack_idx > 0:
|
||||
pre_pos = self.cal_map(pack_idx - 1)
|
||||
pre_token_id = self.len_samples_shuffled[pre_pos] - (
|
||||
self.acm_len_samples[pre_pos] - (pack_idx) * self.packed_length
|
||||
)
|
||||
if pre_token_id == self.len_samples_shuffled[pre_pos]:
|
||||
pre_pos += 1
|
||||
pre_token_id = 0
|
||||
|
||||
pos = self.cal_map(pack_idx)
|
||||
token_id = self.len_samples_shuffled[pos] - (self.acm_len_samples[pos] - (pack_idx + 1) * self.packed_length)
|
||||
return pre_pos, pre_token_id, pos, token_id
|
||||
|
||||
def build_pack(self, pre_pos: int, pre_token_id: int, pos: int, token_id: int):
|
||||
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
|
||||
|
||||
while pre_pos < pos:
|
||||
sample_idx = self.sample_indices[pre_pos]
|
||||
sample = self.dataset[sample_idx]
|
||||
chunk = sample["tokens"][pre_token_id:]
|
||||
pack.extend(chunk)
|
||||
_labels = deepcopy(chunk)
|
||||
_labels = list(_labels[1:]) + [-100]
|
||||
assert len(_labels) == len(chunk), (_labels, chunk)
|
||||
labels.extend(_labels)
|
||||
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
|
||||
num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
|
||||
for _ in range(num_new_samples):
|
||||
cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
|
||||
indexes.extend(list(range(self.max_length_per_sample)))
|
||||
if tokens_left > 0:
|
||||
cu_seqlens.append(cu_seqlens[-1] + tokens_left)
|
||||
indexes.extend(list(range(tokens_left)))
|
||||
pre_pos = pre_pos + 1
|
||||
pre_token_id = 0
|
||||
|
||||
sample_idx = self.sample_indices[pos]
|
||||
sample = self.dataset[sample_idx]
|
||||
chunk = sample["tokens"][pre_token_id:token_id] # fragement of a sample
|
||||
pack.extend(chunk)
|
||||
_labels = deepcopy(chunk)
|
||||
if token_id == len(sample["tokens"]):
|
||||
_labels = list(_labels[1:]) + [-100]
|
||||
else:
|
||||
if token_id > len(sample["tokens"]):
|
||||
print(f"token_id {token_id}, len of sample {len(sample['tokens'])}")
|
||||
_labels = list(_labels[1:]) + [sample["tokens"][token_id]]
|
||||
assert len(_labels) == len(chunk), (_labels, chunk)
|
||||
labels.extend(_labels)
|
||||
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
|
||||
num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
|
||||
for _ in range(num_new_samples):
|
||||
cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
|
||||
indexes.extend(list(range(self.max_length_per_sample)))
|
||||
if tokens_left > 0:
|
||||
cu_seqlens.append(cu_seqlens[-1] + tokens_left)
|
||||
indexes.extend(list(range(tokens_left)))
|
||||
|
||||
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
|
||||
return out
|
||||
|
||||
def __getitem__(self, item: int) -> Dict:
|
||||
"""Given the index, it returns a dict as
|
||||
{
|
||||
'tokens': List[int],
|
||||
'cu_seqlens': List[int],
|
||||
'indexes': List[int], # denotes positional vector as 'tokens'
|
||||
'labels': List[int], # corresponds to 'tokens' and shifted yet, -100 means skipping prediction
|
||||
}
|
||||
"""
|
||||
|
||||
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
|
||||
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
|
||||
|
||||
|
||||
class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
|
||||
"""
|
||||
A dataset wrapper that aggregates samples with different lengths based on packed_length.
|
||||
If a sample is shorter than max_length_per_sample, it will be merged with other samples.
|
||||
For example, given a dataset with 10 samples:
|
||||
[1, 2, 3, 4, 5]
|
||||
[6, 7]
|
||||
[8, 9, 10, 11]
|
||||
[12, ..., 100]
|
||||
...
|
||||
|
||||
Args:
|
||||
dataset: The original dataset to be wrapped.
|
||||
max_length_per_sample (int): The maximum length allowed for each sample.
|
||||
packed_length (int): The desired length for each packed sample.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
max_length_per_sample: int = 2048,
|
||||
packed_length: int = 4096,
|
||||
debug=False,
|
||||
):
|
||||
assert packed_length % max_length_per_sample == 0
|
||||
assert hasattr(dataset, "lengths")
|
||||
assert len(getattr(dataset, "lengths")) == len(
|
||||
dataset
|
||||
), "The dataset must have lengths attribute and have the same length as the dataset"
|
||||
self.dataset = dataset
|
||||
self.max_length_per_sample = max_length_per_sample
|
||||
self.lengths = getattr(self.dataset, "lengths")
|
||||
self.bsz = packed_length // max_length_per_sample
|
||||
self.packed_length = packed_length
|
||||
self.debug = debug
|
||||
# Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting
|
||||
|
||||
self.seed = DEFAULT_SEED
|
||||
indices = np.arange(len(self.lengths))
|
||||
rng = np.random.RandomState(self.seed)
|
||||
rng.shuffle(indices)
|
||||
self.indices = indices
|
||||
self.cum_lens = np.cumsum(self.lengths[self.indices])
|
||||
self.num_tokens = sum(self.lengths)
|
||||
|
||||
def get_dataset_name(self):
|
||||
return self.dataset.get_dataset_name()
|
||||
|
||||
def __len__(self):
|
||||
n_packs = self.num_tokens // self.packed_length
|
||||
return n_packs
|
||||
|
||||
def find_offset(self, offset):
|
||||
idx = np.searchsorted(self.cum_lens, offset, side="right")
|
||||
if idx == 0:
|
||||
return idx, offset
|
||||
length = offset - self.cum_lens[idx - 1]
|
||||
return idx, length
|
||||
|
||||
def pdebug(self, line):
|
||||
if self.debug:
|
||||
print(line, flush=True)
|
||||
|
||||
def __getitem__(self, item: int) -> Dict:
|
||||
"""Given the index, it returns a dict as
|
||||
{
|
||||
'tokens': List[int],
|
||||
'cu_seqlens': List[int],
|
||||
'indexes': List[int], # denotes positional vector as 'tokens'
|
||||
'labels': List[int], # corresponds to 'tokens' and shifted yet, -100 means skipping prediction
|
||||
}
|
||||
"""
|
||||
|
||||
start_idx, start_length = self.find_offset(item * self.packed_length)
|
||||
end_idx, end_length = self.find_offset((item + 1) * self.packed_length)
|
||||
pack_tokens = []
|
||||
pack_labels = []
|
||||
type_ids = []
|
||||
|
||||
self.pdebug(f"item : {item}, start_idx:{start_idx}, start_length:{start_length} ")
|
||||
self.pdebug(f"item : {item}, end_idx:{end_idx}, end_length:{end_length} ")
|
||||
|
||||
if start_idx == end_idx:
|
||||
idx = self.indices[start_idx]
|
||||
sample = self.dataset[idx]
|
||||
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
||||
tokens = sample["tokens"][start_length:end_length]
|
||||
pack_tokens.extend(tokens)
|
||||
pack_labels.extend(tokens[1:] + [-100])
|
||||
type_ids.extend([sample["type_id"]] * len(tokens))
|
||||
return {
|
||||
"tokens": pack_tokens,
|
||||
"cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
|
||||
"indexes": list(range(self.max_length_per_sample)) * self.bsz,
|
||||
"labels": pack_labels,
|
||||
"type_ids": type_ids,
|
||||
}
|
||||
|
||||
idx = self.indices[start_idx]
|
||||
sample = self.dataset[idx]
|
||||
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
||||
tokens = sample["tokens"][start_length:]
|
||||
pack_tokens.extend(tokens)
|
||||
pack_labels.extend(tokens[1:] + [-100])
|
||||
type_ids.extend([sample["type_id"]] * len(tokens))
|
||||
|
||||
for i in range(start_idx + 1, end_idx):
|
||||
idx = self.indices[i]
|
||||
sample = self.dataset[idx]
|
||||
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
||||
tokens = sample["tokens"]
|
||||
pack_tokens.extend(tokens)
|
||||
pack_labels.extend(tokens[1:] + [-100])
|
||||
type_ids.extend([sample.get("type_id")] * len(tokens))
|
||||
|
||||
# corner case, the last sample is useless
|
||||
if end_length == 0:
|
||||
pass
|
||||
else:
|
||||
idx = self.indices[end_idx]
|
||||
sample = self.dataset[idx]
|
||||
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
||||
tokens = sample["tokens"][:end_length]
|
||||
pack_tokens.extend(tokens)
|
||||
pack_labels.extend(tokens[1:] + [-100])
|
||||
type_ids.extend([sample.get("type_id")] * len(tokens))
|
||||
|
||||
return {
|
||||
"tokens": pack_tokens,
|
||||
"cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
|
||||
"indexes": list(range(self.max_length_per_sample)) * self.bsz,
|
||||
"labels": pack_labels,
|
||||
"type_ids": type_ids,
|
||||
}
|
||||
|
||||
|
||||
def get_packed_dataset_without_short_length(
|
||||
folder,
|
||||
max_length_per_sample=2048,
|
||||
packed_length=4096,
|
||||
show_progress=False,
|
||||
min_length=50,
|
||||
min_length_dict=None,
|
||||
pack_into_one_sample=False,
|
||||
):
|
||||
"""
|
||||
Given a folder, combine all the .bin files into a single large dataset.
|
||||
And filter out short samples with length less than 'min_length'.
|
||||
|
||||
Each .bin file is treated as a separate dataset.
|
||||
|
||||
Args:
|
||||
folder (str): Path to the folder containing the .bin files.
|
||||
max_length_per_sample (int): Maximum length of each sample.
|
||||
packed_length (int): Length to pack samples to.
|
||||
show_progress (bool): Whether to show the progress bar.
|
||||
min_length (int): The minimum length of the sample.
|
||||
min_length_dict (dict): The minimum length of the sample for each dataset.
|
||||
The format is something like {'pile-arxiv': 50}
|
||||
dataset_backend (Optional[str]): Dataset storage location. Optional parameters are local, local-shm, kv
|
||||
|
||||
Returns:
|
||||
A packed dataset containing all the data from the .bin files.
|
||||
"""
|
||||
|
||||
assert os.path.exists(folder), f"{folder} does not exist."
|
||||
datasets = []
|
||||
delete_samples = 0
|
||||
|
||||
for root, dirs, files in os.walk(folder, followlinks=True):
|
||||
dirs.sort() # Let the folder need to be returned in a fixed order
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Reading {root}...")
|
||||
num_token_in_folder = 0
|
||||
|
||||
for fn in tqdm(sorted(files), total=len(files), leave=False, disable=not show_progress):
|
||||
if fn.endswith(".bin"):
|
||||
fp = os.path.join(root, fn)
|
||||
catch_ml_keys = []
|
||||
min_length_num = min_length
|
||||
if min_length_dict is not None:
|
||||
for k, v in min_length_dict.items():
|
||||
if k in fp:
|
||||
min_length_num = v
|
||||
catch_ml_keys.append(k)
|
||||
assert (
|
||||
len(catch_ml_keys) < 2
|
||||
), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}"
|
||||
|
||||
ds_type_id = get_dataset_type_id(path=fp)
|
||||
ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)
|
||||
|
||||
if hasattr(ds, "old_length"):
|
||||
delete_samples += ds.old_length - len(ds)
|
||||
if len(ds) == 0:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"None of the data in `{fp}` is longer than {min_length}")
|
||||
continue
|
||||
|
||||
if pack_into_one_sample:
|
||||
ds = PackedDatasetWithoutCuSeqlen(ds, max_length_per_sample, packed_length)
|
||||
else:
|
||||
ds = PackedDataset(ds, max_length_per_sample, packed_length)
|
||||
|
||||
num_token_in_folder += len(ds) * packed_length
|
||||
datasets.append(ds)
|
||||
|
||||
dataset = ConcatDataset(datasets=datasets)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"Find `{len(datasets)}` datasets, \
|
||||
{len(dataset)} samples, \
|
||||
delete `{delete_samples}` because of short length",
|
||||
)
|
||||
|
||||
return dataset
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
"""
|
||||
A .bin file corresponds to a Dataset instance here.
|
||||
"""
|
||||
|
||||
import json
|
||||
import mmap
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class JsonlDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
||||
JSONL format is expected to roughly follow that of The Pile.
|
||||
One-line-per-document of the form:
|
||||
```
|
||||
{
|
||||
"tokens": List[int],
|
||||
}
|
||||
```
|
||||
|
||||
Note that only the "tokens" key is used.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str, dataset_type_id: int = 0, min_length=50):
|
||||
self.path = path
|
||||
self.threadlocal = threading.local()
|
||||
resolved_path = Path(path).resolve()
|
||||
self.resolved_path = resolved_path
|
||||
self.meta = Path(f"{resolved_path}.meta")
|
||||
self.type_id = dataset_type_id
|
||||
|
||||
# only build the cache in on the primary worker to prevent overloading nfs
|
||||
assert os.path.exists(self.meta), f"The cache file:{self.meta} is not found for file:{self.path}"
|
||||
try:
|
||||
with open(self.meta, "rb") as f:
|
||||
meta = np.load(f)
|
||||
except Exception as e:
|
||||
print(f"Cannot load file {self.meta}...")
|
||||
raise e
|
||||
self.offsets = meta[:, 0]
|
||||
self.lengths = meta[:, -1]
|
||||
|
||||
if min_length > 0:
|
||||
mask = self.lengths >= min_length
|
||||
self.old_lengths = self.lengths.copy()
|
||||
self.old_length = len(self.offsets)
|
||||
self.offsets = self.offsets[mask]
|
||||
self.lengths = self.lengths[mask]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
f = self._get_mmap()
|
||||
position = self.offsets[idx]
|
||||
f.seek(position)
|
||||
item = f.readline().decode("utf-8")
|
||||
try:
|
||||
item = json.loads(item)
|
||||
item["length"] = len(item["tokens"]) # add a length info
|
||||
item["type_id"] = self.type_id
|
||||
except Exception as err:
|
||||
raise json.decoder.JSONDecodeError(
|
||||
doc=self.path,
|
||||
pos=position,
|
||||
msg=(
|
||||
f"Error while loading JSONL line in file {self.path} at byte "
|
||||
f"{position}. Contents of line:\n{item}\n{err}"
|
||||
),
|
||||
)
|
||||
return item
|
||||
|
||||
def get_dataset_name(self):
|
||||
return str(self.resolved_path)
|
||||
|
||||
def _get_mmap(self):
|
||||
if not hasattr(self.threadlocal, "handles"):
|
||||
with open(self.path, "rb") as f:
|
||||
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
self.threadlocal.handles = [f, mm]
|
||||
if self.path.endswith(".gz") or self.path.endswith(".bz") or self.path.endswith(".bz2"):
|
||||
raise NotImplementedError(
|
||||
"Compressed files are not supported because .seek() would require "
|
||||
"rereading the entire file, making performance too slow."
|
||||
)
|
||||
return self.threadlocal.handles[-1]
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self.threadlocal = threading.local()
|
||||
|
||||
def __getstate__(self):
|
||||
d = {}
|
||||
for i, v in self.__dict__.items():
|
||||
if i != "threadlocal":
|
||||
d[i] = v
|
||||
return d
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self.threadlocal, "handles"):
|
||||
# cleanup files we opened on initialization
|
||||
while self.threadlocal.handles:
|
||||
self.threadlocal.handles.pop().close()
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(path)
|
||||
|
||||
def __len__(self):
|
||||
# Virtual length of the dataset depends on the epoch number if the number of documents
|
||||
# is not perfectly divisible by the data_subshard_count
|
||||
return len(self.offsets)
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
|
||||
|
||||
|
||||
def get_dataset_type_id(path):
|
||||
import re
|
||||
|
||||
match_idxes = []
|
||||
for key, idx in DATASET_TYPE_IDS_MAP.items():
|
||||
if re.search(rf"/[z_]*{key}/", path):
|
||||
match_idxes.append(idx)
|
||||
assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
|
||||
return match_idxes[0]
|
|
@ -0,0 +1,9 @@
|
|||
from .initialize_trainer import initialize_trainer
|
||||
from .launch import get_default_parser, launch_from_slurm, launch_from_torch
|
||||
|
||||
__all__ = [
|
||||
"get_default_parser",
|
||||
"initialize_trainer",
|
||||
"launch_from_slurm",
|
||||
"launch_from_torch",
|
||||
]
|
|
@ -0,0 +1,34 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
def scaled_init_method_normal(sigma, num_layers):
|
||||
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
|
||||
std = sigma / math.sqrt(2.0 * num_layers)
|
||||
|
||||
def init_(tensor):
|
||||
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
|
||||
|
||||
return init_
|
||||
|
||||
|
||||
def normal_(mean: float = 0.0, std: float = 1.0):
|
||||
r"""Return the initializer filling the input Tensor with values drawn from the normal distribution
|
||||
|
||||
.. math::
|
||||
\mathcal{N}(\text{mean}, \text{std}^2)
|
||||
|
||||
Args:
|
||||
mean (float): the mean of the normal distribution. Defaults 0.0.
|
||||
std (float): the standard deviation of the normal distribution. Defaults 1.0.
|
||||
"""
|
||||
|
||||
def initializer(tensor: Tensor):
|
||||
return nn.init.normal_(tensor, mean, std)
|
||||
|
||||
return initializer
|
|
@ -0,0 +1,84 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
|
||||
|
||||
from typing import Callable, Iterable, Optional, Tuple
|
||||
|
||||
from torch import nn
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
|
||||
from internlm.core.no_pipeline_scheduler import NonPipelineScheduler
|
||||
from internlm.core.trainer import Trainer
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
|
||||
from internlm.utils.common import get_current_device
|
||||
|
||||
|
||||
def initialize_trainer(
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Optional[_Loss] = None,
|
||||
train_dataloader: Optional[Iterable] = None,
|
||||
test_dataloader: Optional[Iterable] = None,
|
||||
lr_scheduler: Optional[_LRScheduler] = None,
|
||||
beta2_scheduler: Optional[Beta2Scheduler] = None,
|
||||
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
|
||||
"""Core function to wrap the essential training components with our functionality based on the config which is
|
||||
loaded into gpc.config.
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model.
|
||||
optimizer (:class:`BaseOptimizer`.
|
||||
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
|
||||
train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
|
||||
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
|
||||
lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
|
||||
|
||||
Returns:
|
||||
Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
|
||||
A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
|
||||
where only ``trainer`` could not be None.
|
||||
"""
|
||||
|
||||
if isinstance(model, nn.Module):
|
||||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
elif isinstance(model, Callable):
|
||||
model = model().to(get_current_device())
|
||||
|
||||
# clip grad norm
|
||||
clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
|
||||
|
||||
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
|
||||
|
||||
# gradient handler, only support PipelineSharedModuleGradientHandler now
|
||||
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
|
||||
gradient_handlers = []
|
||||
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
|
||||
for config in gradient_handler_cfg:
|
||||
if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
|
||||
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
|
||||
gradient_handlers.append(handler)
|
||||
|
||||
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
|
||||
|
||||
engine = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
criterion=criterion,
|
||||
gradient_handlers=gradient_handlers,
|
||||
clip_grad_norm=clip_grad_norm,
|
||||
)
|
||||
|
||||
trainer = Trainer(engine, scheduler)
|
||||
|
||||
return trainer, train_dataloader, test_dataloader, lr_scheduler
|
|
@ -0,0 +1,296 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.context import Config
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def get_default_parser():
|
||||
"""Reads user command line and uses an argument parser to parse the input arguments.
|
||||
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
|
||||
|
||||
Returns:
|
||||
Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, help="path to the config file")
|
||||
parser.add_argument(
|
||||
"--launcher",
|
||||
type=str,
|
||||
default="slurm",
|
||||
choices=["slurm", "torch"],
|
||||
help="launcher for launching distributed environment",
|
||||
)
|
||||
parser.add_argument("--host", type=str, help="the master address for distributed training")
|
||||
parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
|
||||
parser.add_argument("--world_size", type=int, help="world size for distributed training")
|
||||
parser.add_argument("--rank", type=int, help="rank for the default process group")
|
||||
parser.add_argument("--local_rank", type=int, help="local rank on the node")
|
||||
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
|
||||
parser.add_argument("--seed", type=int, default=1024)
|
||||
return parser
|
||||
|
||||
|
||||
def args_sanity_check():
|
||||
assert gpc.config is not None, "config is not load!"
|
||||
|
||||
# the default model type is INTERNLM
|
||||
if "model_type" not in gpc.config:
|
||||
gpc.config._add_item("model_type", "INTERNLM")
|
||||
|
||||
# procssing the parallel config in gpc
|
||||
if "zero1" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("zero1", -1)
|
||||
|
||||
if "pipeline" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("pipeline", 1)
|
||||
|
||||
if "tensor" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("tensor", 1)
|
||||
|
||||
# processing the data config in gpc
|
||||
data = gpc.config.data
|
||||
|
||||
assert data.seq_len is not None, "'seq_len' must be given a value"
|
||||
assert data.micro_bsz is not None, "'micro_bsz' must be given a value"
|
||||
|
||||
if "packed_length" in data and gpc.is_rank_for_log():
|
||||
logger.warning("packed_length would be ignored and will be setted as seq_len * micro_bsz.")
|
||||
|
||||
data._add_item("packed_length", data.seq_len * data.micro_bsz)
|
||||
|
||||
if "micro_num" not in data:
|
||||
data._add_item("micro_num", 1)
|
||||
|
||||
data._add_item("gradient_accumulation", data.micro_num)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"gradient_accumulation size will be setted to {data.micro_num}.")
|
||||
|
||||
# batch_size should be equal with micro_num, should not use it directly
|
||||
data._add_item("batch_size", data.micro_num)
|
||||
|
||||
if "min_length" not in data:
|
||||
data._add_item("min_length", 0)
|
||||
|
||||
if "train_folder" not in data:
|
||||
data._add_item("train_folder", None)
|
||||
|
||||
if "valid_folder" not in data:
|
||||
data._add_item("valid_folder", None)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+++++++++++++++++++++++++++++++ Data Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"seq_len: {data.seq_len}")
|
||||
logger.info(f"micro_num: {data.micro_num}")
|
||||
logger.info(f"micro_bsz: {data.micro_bsz}")
|
||||
logger.info(f"packed_length: {data.packed_length}")
|
||||
logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
|
||||
logger.info(f"min_length: {data.min_length}")
|
||||
|
||||
# processing the checkpoint config
|
||||
if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:
|
||||
gpc.config.ckpt._add_item("checkpoint_every", float("inf"))
|
||||
|
||||
if "load_optimizer" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("load_optimizer", True)
|
||||
|
||||
if "save_ckpt_folder" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("save_ckpt_folder", None)
|
||||
|
||||
if "load_ckpt_folder" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("load_ckpt_folder", None)
|
||||
|
||||
if "load_model_only_folder" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("load_model_only_folder", None)
|
||||
|
||||
assert not (
|
||||
gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
|
||||
), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
|
||||
|
||||
gpc.config.ckpt._add_item(
|
||||
"enable_ckpt", gpc.config.ckpt.save_ckpt_folder is not None and gpc.config.ckpt.checkpoint_every > 0
|
||||
)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+++++++++++++++++++++++++++++++ Ckpt Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
|
||||
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
|
||||
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
|
||||
|
||||
# cudnn
|
||||
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
|
||||
torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)
|
||||
clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+++++++++++++++++++++++++++++++ other Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"cudnn.benchmark: {torch.backends.cudnn.benchmark }")
|
||||
logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic }")
|
||||
logger.info(f"clip_grad_norm: {clip_grad_norm}")
|
||||
|
||||
if "dtype" not in gpc.config.model:
|
||||
logger.warning("dtype is not set, use torch.float16 by defalut!")
|
||||
gpc.config.model._add_item("dtype", torch.float16)
|
||||
else:
|
||||
if gpc.config.model.dtype == "torch.bfloat16":
|
||||
gpc.config.model.dtype = torch.bfloat16
|
||||
elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
|
||||
gpc.config.model.dtype = torch.float16
|
||||
else:
|
||||
assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+++++++++++++++++++++++++++++++ Model Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"Model: {gpc.config.model}")
|
||||
|
||||
logger.info("+++++++++++++++++++++++++++++++ grad_scaler Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"grad_scaler: {gpc.config.grad_scaler}")
|
||||
|
||||
logger.info("+++++++++++++++++++++++++++++++ hybrid_zero_optimizer Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"hybrid_zero_optimizer: {gpc.config.hybrid_zero_optimizer}")
|
||||
|
||||
logger.info("+++++++++++++++++++++++++++++++ adam Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"adam: {gpc.config.adam}")
|
||||
|
||||
logger.info("+++++++++++++++++++++++++++++++ beta2_scheduler Info +++++++++++++++++++++++++++++++")
|
||||
logger.info(f"beta2_scheduler: {gpc.config.beta2_scheduler}")
|
||||
logger.info("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
|
||||
|
||||
|
||||
def launch(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
rank: int,
|
||||
world_size: int,
|
||||
host: str,
|
||||
port: int,
|
||||
backend: str = "nccl",
|
||||
local_rank: int = None,
|
||||
seed: int = 1024,
|
||||
):
|
||||
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
|
||||
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
|
||||
|
||||
Args:
|
||||
config (Union[str, dict, Config]): Config file or config file path are both acceptable
|
||||
rank (int): Rank for the default process group
|
||||
world_size (int): World size of the default process group
|
||||
host (str): The master address for distributed training
|
||||
port (str): The master port for distributed training
|
||||
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
|
||||
local_rank (int, optional):
|
||||
Rank for the process on the node and is used to set the default CUDA device,
|
||||
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
|
||||
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
||||
|
||||
Raises:
|
||||
Exception: Raise exception when config type is wrong
|
||||
"""
|
||||
|
||||
# set config
|
||||
assert isinstance(
|
||||
config, (Config, str, Path, dict)
|
||||
), f"expected argument config to be Config, str or Path, but got {type(config)}"
|
||||
if not isinstance(config, Config) and isinstance(config, dict):
|
||||
config = Config(config)
|
||||
if isinstance(config, (str, Path)):
|
||||
config = Config.from_file(config)
|
||||
gpc.load_config(config)
|
||||
|
||||
# init default process group
|
||||
gpc.init_global_dist(rank, world_size, backend, host, port)
|
||||
|
||||
# init process groups for different parallel modes from config
|
||||
gpc.init_parallel_groups()
|
||||
|
||||
args_sanity_check()
|
||||
|
||||
# set cuda device
|
||||
if torch.cuda.is_available():
|
||||
# if local rank is not given, calculate automatically
|
||||
gpc.set_device(local_rank)
|
||||
|
||||
# set the number of processes running on the same node
|
||||
gpc.detect_num_processes_on_current_node()
|
||||
|
||||
gpc.set_seed(seed)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"Distributed environment is initialized, "
|
||||
f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, "
|
||||
f"tensor parallel size: {gpc.tensor_parallel_size}",
|
||||
)
|
||||
|
||||
|
||||
def launch_from_slurm(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
host: str,
|
||||
port: int,
|
||||
backend: str = "nccl",
|
||||
seed: int = 1024,
|
||||
):
|
||||
"""A wrapper for internlm.launch for SLURM launcher by reading rank and world size from the environment variables
|
||||
set by SLURM
|
||||
|
||||
Args:
|
||||
config (Union[str, dict, Config]): Config file or config file path are both acceptable
|
||||
host (str): The master address for distributed training
|
||||
port (str): The master port for distributed training
|
||||
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
|
||||
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
||||
"""
|
||||
try:
|
||||
rank = int(os.environ["SLURM_PROCID"])
|
||||
world_size = int(os.environ["SLURM_NPROCS"])
|
||||
except KeyError as e:
|
||||
raise RuntimeError(f"Could not find {e} in the SLURM environment")
|
||||
|
||||
launch(
|
||||
config=config,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host=host,
|
||||
port=port,
|
||||
backend=backend,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
|
||||
def launch_from_torch(config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024):
|
||||
"""A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
|
||||
from the environment variables set by PyTorch
|
||||
|
||||
Args:
|
||||
config (Union[str, dict, Config]): Config file or config file path are both acceptable
|
||||
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
|
||||
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
||||
"""
|
||||
try:
|
||||
rank = int(os.environ["RANK"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
host = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"])
|
||||
except KeyError as e:
|
||||
raise RuntimeError(f"Could not find {e} in the torch environment")
|
||||
|
||||
launch(
|
||||
config=config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host=host,
|
||||
port=port,
|
||||
backend=backend,
|
||||
seed=seed,
|
||||
)
|
|
@ -0,0 +1,19 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from .embedding import Embedding1D, RotaryEmbedding
|
||||
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
|
||||
from .modeling_internlm import build_model_with_cfg
|
||||
from .multi_head_attention import MHA
|
||||
from .utils import gather_forward_split_backward
|
||||
|
||||
__all__ = [
|
||||
"Embedding1D",
|
||||
"FeedForward",
|
||||
"RotaryEmbedding",
|
||||
"RewardModelLinear",
|
||||
"ScaleColumnParallelLinear",
|
||||
"MHA",
|
||||
"gather_forward_split_backward",
|
||||
"build_model_with_cfg",
|
||||
]
|
|
@ -0,0 +1,209 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import rotary_emb
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
|
||||
from torch import Tensor, nn
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
from .utils import gather_forward_split_backward
|
||||
|
||||
|
||||
class Embedding1D(nn.Module):
|
||||
"""
|
||||
1D Embedding.
|
||||
|
||||
Args:
|
||||
num_embeddings (int): The size of vocab.
|
||||
embedding_dim (int): The dimention of model.
|
||||
padding_idx (int): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
|
||||
therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
|
||||
i.e. it remains as a fixed "pad". None by default.
|
||||
dtype (Optional[torch.dtype]): Data type None by default.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
*args,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
"""
|
||||
ApplyRotaryEmbQKV_
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
|
||||
"""
|
||||
qkv: (total, 3, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2)
|
||||
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
|
||||
rotary_dim must be <= headdim
|
||||
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
|
||||
"""
|
||||
_, three, _, headdim = qkv.shape
|
||||
assert three == 3
|
||||
rotary_seqlen, rotary_dim = cos.shape
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim
|
||||
cos_k = cos if cos_k is None else cos_k
|
||||
sin_k = sin if sin_k is None else sin_k
|
||||
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
|
||||
k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(
|
||||
k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
|
||||
)
|
||||
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
||||
return qkv
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dqkv):
|
||||
cos, sin, cos_k, sin_k = ctx.saved_tensors
|
||||
rotary_dim = cos.shape[-1]
|
||||
rotary_dim *= 2
|
||||
dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(
|
||||
dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
|
||||
)
|
||||
dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(
|
||||
dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
|
||||
)
|
||||
return dqkv, None, None, None, None
|
||||
|
||||
|
||||
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
||||
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
"""
|
||||
The rotary position embeddings from RoFormer_ (Su et. al).
|
||||
A crucial insight from the method is that the query and keys are
|
||||
transformed by rotation matrices which depend on the relative positions.
|
||||
|
||||
Other implementations are available in the Rotary Transformer repo_ and in
|
||||
GPT-NeoX_, GPT-NeoX was an inspiration
|
||||
|
||||
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
||||
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
||||
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
||||
|
||||
If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
||||
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
||||
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, base=10000, scale_base=0, device=None):
|
||||
""" """
|
||||
super().__init__()
|
||||
# Generate and save the inverse frequency buffer (non trainable)
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.scale_base = scale_base
|
||||
scale = (
|
||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
||||
if scale_base > 0
|
||||
else None
|
||||
)
|
||||
self.register_buffer("scale", scale)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _update_cos_sin_cache(self, x, indexes):
|
||||
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
|
||||
if not isinstance(indexes, int):
|
||||
seqlen = indexes.max().item() + 1
|
||||
else:
|
||||
seqlen = indexes + 1 # eval_forward
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
||||
else:
|
||||
power = (
|
||||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
||||
|
||||
def forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self._update_cos_sin_cache(qkv, indexes)
|
||||
if self.scale is None:
|
||||
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
|
||||
else:
|
||||
return apply_rotary_emb_qkv_(
|
||||
qkv,
|
||||
self._cos_cached[indexes],
|
||||
self._sin_cached[indexes],
|
||||
self._cos_k_cached[indexes],
|
||||
self._sin_k_cached[indexes],
|
||||
)
|
||||
|
||||
def eval_forward(self, qkv, seqlen_offset=0):
|
||||
"""
|
||||
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
||||
token in the batch.
|
||||
"""
|
||||
self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
|
||||
if self.scale is None:
|
||||
return legacy_apply_rotary_embed_qkv(
|
||||
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
|
||||
)
|
||||
else:
|
||||
return legacy_apply_rotary_embed_qkv(
|
||||
qkv,
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
self._cos_k_cached[seqlen_offset:],
|
||||
self._sin_k_cached[seqlen_offset:],
|
||||
)
|
|
@ -0,0 +1,176 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from flash_attn.ops.fused_dense import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
fused_dense_func,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
class ScaleColumnParallelLinear(nn.Linear):
|
||||
"""
|
||||
ScaleColumnParallelLinear.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample
|
||||
out_features (int): size of each output sample
|
||||
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
|
||||
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
|
||||
in the config.
|
||||
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
||||
we do an all_gather of x before doing the matmul.
|
||||
If not, then the input is already gathered.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
dtype (Optional[torch.dtype]): The type of data.
|
||||
weight_scale (int): For training stability. 1 by default.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup],
|
||||
bias: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
weight_scale: int = 1,
|
||||
) -> None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if out_features % world_size != 0:
|
||||
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
|
||||
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
def forward(self, input): # pylint: disable=W0622
|
||||
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
||||
# we do an all_gather of x before doing the matmul.
|
||||
# If not, then the input is already gathered.
|
||||
if self.weight_scale != 1:
|
||||
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
|
||||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
||||
)
|
||||
|
||||
|
||||
class RewardModelLinear(ScaleColumnParallelLinear):
|
||||
"""
|
||||
RewardModelLinear.
|
||||
Args:
|
||||
in_features (int): size of each input sample
|
||||
out_features (int): size of each output sample
|
||||
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
|
||||
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
|
||||
in the config.
|
||||
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
||||
we do an all_gather of x before doing the matmul.
|
||||
If not, then the input is already gathered.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
dtype (Optional[torch.dtype]): The type of data.
|
||||
weight_scale (int): For training stability. 1 by default.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup],
|
||||
bias: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
weight_scale: int = 1,
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
|
||||
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||
if bias:
|
||||
torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||
|
||||
def forward(self, input): # pylint: disable=W0622
|
||||
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
||||
# we do an all_gather of x before doing the matmul.
|
||||
# If not, then the input is already gathered.
|
||||
if self.weight_scale != 1:
|
||||
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
|
||||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
||||
)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
FeedForward.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample
|
||||
hidden_features (int): size of hidden state of FFN
|
||||
out_features (int): size of each output sample
|
||||
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
|
||||
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
|
||||
in the config.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
dtype (Optional[torch.dtype]): The type of data.
|
||||
multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
out_features: int = None,
|
||||
process_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||
bias: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
multiple_of: int = 256,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = ColumnParallelLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w2 = ColumnParallelLinear(
|
||||
in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
|
||||
)
|
||||
self.w3 = RowParallelLinear(
|
||||
hidden_features,
|
||||
out_features,
|
||||
process_group,
|
||||
bias=bias,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# need to assign tp attribute so that colossalai know it is tensor parallel module
|
||||
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
for name in ["w1", "w2", "w3"]:
|
||||
for param in getattr(self, name).parameters():
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.w3(F.silu(self.w1(x)) * self.w2(x))
|
||||
return out
|
|
@ -0,0 +1,54 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
class FlashGPTLMLoss(nn.Module):
|
||||
"""
|
||||
Loss function for flash GPT Language Model.
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_output=True, label_smoothing=0):
|
||||
super().__init__()
|
||||
|
||||
if label_smoothing is not None:
|
||||
if label_smoothing != 0:
|
||||
if gpc.is_rank_for_log():
|
||||
print(f"use label_smoothing: {label_smoothing}")
|
||||
else:
|
||||
label_smoothing = 0
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
if parallel_output:
|
||||
self.loss_fn = FlashCrossEntropyLoss(
|
||||
reduction="mean",
|
||||
inplace_backward=True,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
label_smoothing=label_smoothing,
|
||||
) # The loss in this place is bound to the gather_output initialized by VocabParallelClassifier1D
|
||||
else:
|
||||
# Here, the output will gather output is set in the model, so use ordinary loss
|
||||
self.loss_fn = nn.CrossEntropyLoss(reduction="mean", label_smoothing=label_smoothing)
|
||||
|
||||
def forward(self, *args):
|
||||
if len(args) == 3:
|
||||
# residual is to match prenorm
|
||||
logits, _, labels = args
|
||||
elif len(args) == 2:
|
||||
# When using postnorm
|
||||
logits, labels = args
|
||||
else:
|
||||
raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
|
||||
shift_logits = logits.contiguous().view(-1, logits.size(-1))
|
||||
shift_labels = labels.contiguous().view(-1)
|
||||
loss = self.loss_fn(
|
||||
shift_logits, shift_labels
|
||||
) # There is no need to consider the ignore_index problem here, because the loss calculation will be
|
||||
# calculated through the calculation range, and -100 must be outside this range, so there is no problem
|
||||
|
||||
return loss
|
|
@ -0,0 +1,511 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
||||
from flash_attn.modules.embedding import ParallelGPT2Embeddings
|
||||
from flash_attn.modules.mlp import ParallelFusedMLP
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
|
||||
from internlm.model.embedding import Embedding1D
|
||||
from internlm.model.linear import (
|
||||
FeedForward,
|
||||
RewardModelLinear,
|
||||
ScaleColumnParallelLinear,
|
||||
)
|
||||
from internlm.model.multi_head_attention import MHA
|
||||
from internlm.model.utils import gather_forward_split_backward
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
from internlm.utils.checkpoint import activation_checkpoint
|
||||
from internlm.utils.common import filter_kwargs
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
|
||||
MODEL_TYPE = "INTERNLM"
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class PackedFlashBaseLayer1D(nn.Module):
|
||||
"""
|
||||
1D Packed Flash Base Layer.
|
||||
|
||||
Args:
|
||||
hidden_size (int): The hidden size of model. 768 by default.
|
||||
num_attention_heads (int): The number of attention heads. 12 by default.
|
||||
mlp_ratio (int): The ratio of MLP layers. 4 by default.
|
||||
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
|
||||
drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
|
||||
dtype (torch.dtype): Type of data. torch.float by default.
|
||||
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
|
||||
checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
|
||||
layer_idx (int): The index of current layer. 0 by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
mlp_ratio: int = 4,
|
||||
attn_drop_rate: float = 0,
|
||||
drop_rate: float = 0.0,
|
||||
dtype: torch.dtype = torch.float,
|
||||
layer_norm_epsilon: float = 1e-6,
|
||||
checkpoint: bool = False,
|
||||
layer_idx: int = 0,
|
||||
residual_in_fp32: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
norm_type: str = "rmsnorm",
|
||||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
# dropout selective checkpoint can only be enabled when checkpoint is disabled.
|
||||
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
self.mixer = MHA(
|
||||
embed_dim=hidden_size,
|
||||
num_heads=num_attention_heads,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
dropout=attn_drop_rate,
|
||||
softmax_scale=1 / math.sqrt(head_dim),
|
||||
causal=True,
|
||||
layer_idx=layer_idx,
|
||||
rotary_emb_dim=head_dim,
|
||||
rotary_emb_scale_base=0,
|
||||
use_flash_attn=True,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
self.dropout1 = nn.Dropout(drop_rate)
|
||||
if norm_type == "rmsnorm":
|
||||
self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
|
||||
if use_swiglu:
|
||||
self.mlp = FeedForward(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
self.mlp = ParallelFusedMLP(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
activation="gelu_approx",
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias1=False,
|
||||
bias2=False,
|
||||
sequence_parallel=False,
|
||||
checkpoint_lvl=0,
|
||||
heuristic="auto",
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.dropout2 = nn.Dropout(drop_rate)
|
||||
self.use_swiglu = use_swiglu
|
||||
self.use_scaled_init = use_scaled_init
|
||||
self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
|
||||
self.return_residual = False
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
with torch.no_grad():
|
||||
for name, param in self.mixer.named_parameters():
|
||||
if param.ndim == 1:
|
||||
param.data.zero_()
|
||||
elif "Wqkv" in name:
|
||||
normal_(std=0.006)(param.data)
|
||||
elif self.use_scaled_init:
|
||||
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
|
||||
else:
|
||||
normal_(std=0.0015)(param.data)
|
||||
|
||||
for name, param in self.mlp.named_parameters():
|
||||
if param.ndim == 1 and "bias" in name:
|
||||
param.data.zero_()
|
||||
elif self.use_swiglu:
|
||||
if self.use_scaled_init and "w2" in name:
|
||||
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
|
||||
else:
|
||||
normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
|
||||
else:
|
||||
if self.use_scaled_init and "fc1" not in name:
|
||||
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
|
||||
else:
|
||||
normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
|
||||
if self.checkpoint and self.training:
|
||||
return activation_checkpoint(
|
||||
self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
|
||||
)
|
||||
else:
|
||||
return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
|
||||
|
||||
def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
hidden_states: the sequence to the encoder layer (required).
|
||||
residual: hidden_states = Attn/MLP(LN(residual))
|
||||
cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
|
||||
indexes: the length of index is same as hidden states, which stand for the current position
|
||||
"""
|
||||
mixer_kwargs = {
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": max_seqlen,
|
||||
"indexes": indexes,
|
||||
"inference_params": inference_params,
|
||||
}
|
||||
|
||||
def _dropout_and_norm_attn(_hidden_states):
|
||||
_dropped = self.dropout1(_hidden_states)
|
||||
_residual = _dropped
|
||||
_hidden_states = self.norm1(_residual.float())
|
||||
return _residual, _hidden_states
|
||||
|
||||
if self.dropout_selective_checkpoint:
|
||||
residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states)
|
||||
else:
|
||||
residual, hidden_states = _dropout_and_norm_attn(hidden_states)
|
||||
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
||||
|
||||
def _dropout_and_norm_ffn(_residual, _hidden_states):
|
||||
_dropped = self.dropout2(_hidden_states)
|
||||
_residual = (_dropped + _residual) if _residual is not None else _dropped
|
||||
_hidden_states = self.norm2(_residual.float())
|
||||
return _residual, _hidden_states
|
||||
|
||||
if self.dropout_selective_checkpoint:
|
||||
residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states)
|
||||
else:
|
||||
residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
|
||||
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class PackedFlashInternLm1D(nn.Module):
|
||||
"""
|
||||
1D Packed Flash InternLm.
|
||||
|
||||
Args:
|
||||
num_layers (int): The number of layer. 12 by default.
|
||||
hidden_size (int): The size of hidden state. 768 by default.
|
||||
num_attention_heads (int): The number of attention head. 12 by default.
|
||||
vocab_size (int): The size of vocabulary. 50304 by default.
|
||||
mlp_ratio (int): The ratio of MLP layers. 4 by default.
|
||||
attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
|
||||
drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
|
||||
dtype (torch.dtype): The type of data. torch.float by default.
|
||||
checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
|
||||
checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number
|
||||
of layers. 1.0 by default.
|
||||
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
|
||||
first (bool): Whether input embedding layer or not. False by default.
|
||||
last (bool): Whether output embedding layer or not. False by default.
|
||||
embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention.
|
||||
True by default.
|
||||
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
|
||||
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
|
||||
start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
mlp_ratio: int = 4.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_rate: float = 0.0,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: bool = False,
|
||||
checkpoint_fraction: float = 1.0,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden: bool = False,
|
||||
embed_grad_scale: float = 0.1,
|
||||
parallel_output: bool = True,
|
||||
start_layer_idx: int = 0,
|
||||
device: Optional[torch.device] = None,
|
||||
residual_in_fp32: bool = False,
|
||||
norm_type: str = "rmsnorm",
|
||||
is_reward: bool = False,
|
||||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if checkpoint_fraction <= 0:
|
||||
checkpoint = False
|
||||
if not checkpoint:
|
||||
checkpoint_fraction = 0
|
||||
checkpoint_layer_num = num_layers * checkpoint_fraction
|
||||
if is_reward:
|
||||
head_cls = RewardModelLinear
|
||||
else:
|
||||
head_cls = ScaleColumnParallelLinear
|
||||
if first:
|
||||
if embed_split_hidden:
|
||||
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
|
||||
else:
|
||||
self.embedding = ParallelGPT2Embeddings(
|
||||
embed_dim=hidden_size,
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=-1,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
padding_idx=None,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.embedding.named_parameters():
|
||||
normal_(std=0.0052)(param)
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
self.embed_grad_scale = embed_grad_scale
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
PackedFlashBaseLayer1D(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
dtype=dtype,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
checkpoint=lid < checkpoint_layer_num,
|
||||
layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
device=device,
|
||||
norm_type=norm_type,
|
||||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
)
|
||||
for lid in range(num_layers)
|
||||
]
|
||||
)
|
||||
if last:
|
||||
if norm_type == "rmsnorm":
|
||||
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.head = head_cls(
|
||||
in_features=hidden_size,
|
||||
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
weight_scale=embed_grad_scale,
|
||||
)
|
||||
for _, param in self.head.named_parameters():
|
||||
normal_(std=0.0052)(param)
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
self.parallel_output = parallel_output
|
||||
|
||||
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
|
||||
# attention_mask: compute attention on the places where the value is 1
|
||||
if hasattr(self, "embedding"):
|
||||
hidden_states = self.embedding(input_ids)
|
||||
if self.embed_grad_scale != 1:
|
||||
hidden_states = (
|
||||
self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
|
||||
)
|
||||
if isinstance(cu_seqlens, list):
|
||||
assert len(cu_seqlens) == 1
|
||||
cu_seqlens = cu_seqlens[0].to(hidden_states.device)
|
||||
|
||||
if cu_seqlens is not None:
|
||||
cu_seqlens = cu_seqlens.squeeze(0)
|
||||
hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in,it indicated a packed state,
|
||||
# the batch dimension with a size of 1 should be directly squeezed off.
|
||||
|
||||
if indexes is not None:
|
||||
assert len(indexes) == 1
|
||||
# The indexes are used to indicate the actual position IDs of each token in the packed input.
|
||||
indexes = indexes[0]
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
|
||||
|
||||
for _, block in enumerate(self.blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
indexes=indexes,
|
||||
inference_params=inference_params,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
if hasattr(self, "norm"):
|
||||
hidden_states = self.norm(hidden_states.float())
|
||||
if hasattr(self, "head"):
|
||||
hidden_states = self.head(hidden_states)
|
||||
|
||||
if not self.parallel_output:
|
||||
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
|
||||
"""
|
||||
build generic model 1d
|
||||
|
||||
Args:
|
||||
num_layers (int): The number of layer.
|
||||
num_chunks (int): The number of partitions in pipeline parallel.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.
|
||||
|
||||
"""
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
# all_parts = partition_uniform_with_embed2(num_layers, pipeline_size, num_chunks)
|
||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
|
||||
models = []
|
||||
|
||||
if kwargs["checkpoint"] is True:
|
||||
kwargs["checkpoint_fraction"] = 1.0
|
||||
else:
|
||||
kwargs["checkpoint_fraction"] = 0
|
||||
|
||||
for start, end in parts:
|
||||
kwargs["num_layers"] = end - start
|
||||
kwargs["first"] = start == 0
|
||||
# If there is no content in the final layer, assign the last layer.
|
||||
kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
|
||||
kwargs["device"] = device
|
||||
kwargs["start_layer_idx"] = start
|
||||
chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)
|
||||
|
||||
models.append(chunk)
|
||||
torch.distributed.barrier()
|
||||
if len(models) == 1:
|
||||
model = models[0]
|
||||
else:
|
||||
model = nn.ModuleList(models)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
|
||||
def build_model_with_cfg(
|
||||
num_chunks=1,
|
||||
checkpoint=False,
|
||||
dtype=torch.float,
|
||||
embed_split_hidden=False,
|
||||
num_layers=48,
|
||||
hidden_size=2048,
|
||||
vocab_size=50304,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
num_attention_heads=32,
|
||||
mlp_ratio=4.0,
|
||||
residual_in_fp32=False,
|
||||
norm_type="rmsnorm",
|
||||
drop_rate=0,
|
||||
attn_drop_rate=0,
|
||||
apply_post_layer_norm=False, # pylint: disable=W0613
|
||||
layer_norm_epsilon=1e-5,
|
||||
is_reward=False,
|
||||
dropout_selective_checkpoint=True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
):
|
||||
"""
|
||||
Builde model with config
|
||||
|
||||
Args:
|
||||
num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
|
||||
checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
|
||||
dtype (torch.dtype): The type of data. torch.float by default.
|
||||
embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention.
|
||||
False by default.
|
||||
num_layers (int): The number of layer. 48 by default.
|
||||
hidden_size (int): The size of hidden state. 2048 by default.
|
||||
vocab_size (int): The size of vocabulary. 50304 by default.
|
||||
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
|
||||
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
|
||||
num_attention_heads (int): The number of attention head. 32 by default.
|
||||
mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
|
||||
because this parameter requires inconsistent data types to be passed between pipelines,
|
||||
which requires significant modifications to internlm.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
drop_rate (float): The dropout rate of input hidden state. 0 by default.
|
||||
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
|
||||
apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
|
||||
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
|
||||
is_reward (bool): Whether to use reward model. False by default.
|
||||
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
|
||||
use_scaled_init (bool): Whether to use scaled init. True by default.
|
||||
use_swiglu (bool): Whether to use swiglu. True by default.
|
||||
|
||||
"""
|
||||
|
||||
cfg = dict(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
vocab_size=vocab_size,
|
||||
embed_grad_scale=embed_grad_scale,
|
||||
parallel_output=parallel_output,
|
||||
mlp_ratio=mlp_ratio,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
norm_type=norm_type,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
is_reward=is_reward,
|
||||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
)
|
||||
|
||||
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
|
|
@ -0,0 +1,170 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from flash_attn.modules.mha import (
|
||||
CrossAttention,
|
||||
FlashCrossAttention,
|
||||
FlashSelfAttention,
|
||||
SelfAttention,
|
||||
_update_kv_cache,
|
||||
)
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.embedding import RotaryEmbedding
|
||||
|
||||
|
||||
class MHA(nn.Module):
|
||||
"""
|
||||
Multi-head self-attention and cross-attention.
|
||||
|
||||
Args:
|
||||
embed_dim (int): The dimention of hidden state.
|
||||
num_heads (int): The number of attention heads.
|
||||
process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
|
||||
bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
|
||||
output projection. True by default.
|
||||
dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
|
||||
softmax_scale (float): The temperature to use for the softmax attention.
|
||||
causal (boolean): Whether to apply causal attention mask. False by default.
|
||||
layer_idx (int): The index of current layer. None by default.
|
||||
rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default.
|
||||
rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
|
||||
XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
|
||||
use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used.
|
||||
False by default.
|
||||
sequence_parallel (boolean): If True, we're doing Tensor Parallel with sequence parallelism. An all_gather_raw
|
||||
of x will be done before doing the matmul.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
dtype (Optional[torch.dtype]): The type of data.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup],
|
||||
dropout: float = 0.0,
|
||||
softmax_scale: float = None,
|
||||
causal: bool = False,
|
||||
layer_idx: int = None,
|
||||
rotary_emb_dim: int = 0,
|
||||
rotary_emb_scale_base: int = 0,
|
||||
use_flash_attn: bool = False,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.causal = causal
|
||||
self.layer_idx = layer_idx
|
||||
self.rotary_emb_dim = rotary_emb_dim
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self.num_heads = num_heads
|
||||
assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
|
||||
self.head_dim = self.embed_dim // num_heads
|
||||
|
||||
if self.rotary_emb_dim > 0:
|
||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
|
||||
|
||||
# notice here should change bias=True
|
||||
self.Wqkv = ColumnParallelLinear(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
process_group,
|
||||
bias=True,
|
||||
sequence_parallel=sequence_parallel,
|
||||
**factory_kwargs,
|
||||
) # according to https://spaces.ac.cn/archives/9577
|
||||
|
||||
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
|
||||
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
|
||||
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
||||
self.inner_cross_attn = inner_cross_attn_cls(
|
||||
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
||||
)
|
||||
|
||||
# output projection always have the bias (for now)
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
|
||||
)
|
||||
# need to assign tp attribute so that internlm know it is tensor parallel module
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
for name in ["out_proj", "Wqkv"]:
|
||||
for param in getattr(self, name).parameters():
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
|
||||
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
|
||||
if kwargs.get("indexes", None) is not None:
|
||||
return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
|
||||
else:
|
||||
return self._forward(x=x, seqlen=seqlen, inference_params=inference_params)
|
||||
|
||||
def _forward(self, x, seqlen=None, inference_params=None):
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
|
||||
If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
|
||||
split x during sequence parallel, we split the batch * seqlen dimension
|
||||
(in case batch is small).
|
||||
"""
|
||||
qkv = self.Wqkv(x)
|
||||
if seqlen is None:
|
||||
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
|
||||
else:
|
||||
qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
|
||||
|
||||
if self.rotary_emb_dim > 0:
|
||||
if inference_params is None:
|
||||
qkv = self.rotary_emb.eval_forward(qkv)
|
||||
else:
|
||||
qkv = self.rotary_emb.eval_forward(qkv, seqlen_offset=inference_params.sequence_len_offset)
|
||||
|
||||
if inference_params is None:
|
||||
context = self.inner_attn(qkv)
|
||||
else:
|
||||
q = qkv[:, :, 0]
|
||||
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
||||
kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
|
||||
# If we're processing the prompt, causal=None (use self.causal).
|
||||
# If we're decoding, then causal=False.
|
||||
causal = None if inference_params.sequence_len_offset == 0 else False
|
||||
context = self.inner_cross_attn(q, kv, causal=causal)
|
||||
|
||||
if seqlen is None:
|
||||
context = rearrange(context, "b s h d -> b s (h d)")
|
||||
else:
|
||||
context = rearrange(context, "b s h d -> (b s) (h d)")
|
||||
|
||||
out = self.out_proj(context)
|
||||
return out
|
||||
|
||||
def _packed_forward(self, x, inference_params=None, **kwargs):
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
|
||||
If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
|
||||
split x during sequence parallel, we split the batch * seqlen dimension
|
||||
(in case batch is small).
|
||||
"""
|
||||
qkv = self.Wqkv(x) # total x hsz'
|
||||
qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d
|
||||
qkv = self.rotary_emb(qkv, kwargs.pop("indexes"))
|
||||
|
||||
if inference_params is None:
|
||||
context = self.inner_attn(qkv, **kwargs)
|
||||
else:
|
||||
raise RuntimeError("Not support this right now")
|
||||
|
||||
context = rearrange(context, "b h d -> b (h d)") # recover the shape
|
||||
out = self.out_proj(context)
|
||||
return out
|
|
@ -0,0 +1,73 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
def _split(input_, parallel_mode, dim=-1):
|
||||
# skip if only one rank involved
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along last dimension.
|
||||
dim_size = input_.size(dim)
|
||||
assert dim_size % world_size == 0, (
|
||||
f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
|
||||
f"cannot split tensor evenly"
|
||||
)
|
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
output = tensor_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather(input_, parallel_mode, dim=-1):
|
||||
# skip if only one rank involved
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
|
||||
torch.distributed.all_gather(tensor_list, input_, group=group)
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatenate.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(input_):
|
||||
return _gather(input_, parallel_mode=None)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode, dim):
|
||||
ctx.mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
return _gather(input_, parallel_mode, dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, parallel_mode, dim):
|
||||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
|
|
@ -0,0 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from .beta2_scheduler import Beta2Scheduler
|
||||
from .lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from .optimizer import HybridZeroOptimizer
|
||||
|
||||
__all__ = ["Beta2Scheduler", "FineTuneCosineAnnealingWarmupLR", "HybridZeroOptimizer"]
|
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Beta2Scheduler:
|
||||
"""
|
||||
Beta2Scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer: torch.optim.Adam, init_beta2, c=0.8, cur_iter=-1):
|
||||
self.cur_iter = 0 if cur_iter == -1 else cur_iter
|
||||
self.init_beta2 = init_beta2
|
||||
self.c = c
|
||||
self.optimizer = optimizer
|
||||
assert isinstance(
|
||||
optimizer, (torch.optim.Adam, torch.optim.AdamW)
|
||||
), "should use Adam optimzier, which has beta2"
|
||||
|
||||
def step(self, cur_iter=None):
|
||||
if cur_iter is None:
|
||||
self.cur_iter += 1
|
||||
else:
|
||||
self.cur_iter = cur_iter
|
||||
|
||||
new_beta2 = self.get_beta2()
|
||||
for pg in self.optimizer.param_groups:
|
||||
beta1, _ = pg["betas"]
|
||||
pg["betas"] = (beta1, new_beta2)
|
||||
|
||||
def get_beta2(self):
|
||||
if self.c <= 0:
|
||||
return self.init_beta2
|
||||
scale = 1 - (1 / self.cur_iter**self.c)
|
||||
return max(self.init_beta2, scale)
|
|
@ -0,0 +1,135 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import json
|
||||
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class WarmupScheduler(_LRScheduler):
|
||||
"""Starts with a linear warmup lr schedule until it reaches N epochs then applies
|
||||
the specific scheduler (For example: ReduceLROnPlateau).
|
||||
|
||||
Args:
|
||||
optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer.
|
||||
warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler.
|
||||
after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler.
|
||||
last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1,
|
||||
the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
|
||||
self.warmup_epochs = int(warmup_epochs)
|
||||
self.after_scheduler = after_scheduler
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
|
||||
if isinstance(state_dict["after_scheduler"], _LRScheduler):
|
||||
state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
|
||||
state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
|
||||
del state_dict["after_scheduler"]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
# state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
|
||||
for key in list(self.__dict__.keys()):
|
||||
if key in state_dict:
|
||||
self.__dict__[key] = state_dict[key]
|
||||
if isinstance(self.after_scheduler, _LRScheduler):
|
||||
assert type(self.after_scheduler).__name__ == state_dict["after_scheduler_type"]
|
||||
# state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
|
||||
self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"])
|
||||
# del state_dict['after_scheduler']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs:
|
||||
if not self.finished:
|
||||
self.after_scheduler.base_lrs = self.base_lrs
|
||||
self.finished = True
|
||||
return self.after_scheduler.get_lr()
|
||||
|
||||
return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]
|
||||
|
||||
def step(self, epoch=None):
|
||||
if self.finished:
|
||||
if epoch is None:
|
||||
self.after_scheduler.step(None)
|
||||
self._last_lr = self.after_scheduler.get_last_lr()
|
||||
else:
|
||||
self.after_scheduler.step(epoch - self.warmup_epochs)
|
||||
self._last_lr = self.after_scheduler.get_last_lr()
|
||||
else:
|
||||
return super().step(epoch)
|
||||
|
||||
|
||||
class CosineAnnealingWarmupLR(WarmupScheduler):
|
||||
"""Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied.
|
||||
|
||||
Args:
|
||||
optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer.
|
||||
total_steps (int): Number of total training steps.
|
||||
warmup_steps (int, optional): Number of warmup steps, defaults to 0.
|
||||
eta_min (int, optional): Minimum learning rate, defaults to 0.
|
||||
last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1,
|
||||
the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0.0, last_epoch: int = -1):
|
||||
base_scheduler = _CosineAnnealingLR(
|
||||
optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch
|
||||
)
|
||||
super().__init__(optimizer, warmup_steps, base_scheduler)
|
||||
|
||||
|
||||
class FineTuneCosineAnnealingWarmupLR(CosineAnnealingWarmupLR):
|
||||
"""
|
||||
FineTune Cosine Annealing Warmup LR.
|
||||
|
||||
Args:
|
||||
optimizer: The optimizer object.
|
||||
total_steps (int): The number of total steps.
|
||||
init_steps (int): The number of init steps, default is 0.
|
||||
warmup_steps (int): The number of warm up steps, default is 0.
|
||||
eta_min (float): The minimum learning rate, default is 0.0.
|
||||
last_epoch: Last epoch, default is -1.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
total_steps: int,
|
||||
init_steps: int = 0,
|
||||
warmup_ratio: float = 0.0,
|
||||
eta_min: float = 0.0,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
self._init_steps = init_steps
|
||||
self._warmup_steps = int(total_steps * warmup_ratio)
|
||||
# Use this value to calculate the lr of warmup, because warmup_epochs = init_steps + warmup_steps
|
||||
super().__init__(optimizer, total_steps, self._warmup_steps + init_steps, eta_min, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs:
|
||||
if not self.finished: # pylint: disable=E0203
|
||||
# This True switch is to avoid warning when the warmup reaches the preset value switch
|
||||
self.after_scheduler._get_lr_called_within_step = True
|
||||
self.after_scheduler.base_lrs = self.base_lrs
|
||||
self.finished = True
|
||||
return self.after_scheduler.get_lr()
|
||||
|
||||
elif self.last_epoch >= self._init_steps:
|
||||
return [(self.last_epoch + 1 - self._init_steps) / self._warmup_steps * lr for lr in self.base_lrs]
|
||||
else:
|
||||
return [0 for lr in self.base_lrs]
|
||||
|
||||
def __str__(self):
|
||||
return json.dumps(self.state_dict(), indent=4, sort_keys=True)
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from .hybrid_zero_optim import HybridZeroOptimizer
|
||||
|
||||
__all__ = ["HybridZeroOptimizer"]
|
|
@ -0,0 +1,818 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import amp_C
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
from torch._six import inf
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from internlm.core.context import Config, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.solver.optimizer.store import (
|
||||
BucketStore,
|
||||
GradientStore,
|
||||
ParameterStore,
|
||||
TensorBucket,
|
||||
)
|
||||
from internlm.solver.optimizer.utils import (
|
||||
DynamicGradScaler,
|
||||
flatten,
|
||||
get_grad_accumulate_object,
|
||||
has_inf_or_nan,
|
||||
reduce_tensor,
|
||||
release_param_grad,
|
||||
split_half_float_double,
|
||||
sync_param,
|
||||
)
|
||||
from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.parallel import is_model_parallel_parameter
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
def calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
norm += grad_norm**norm_type
|
||||
return norm
|
||||
|
||||
|
||||
class BaseOptimizer(Optimizer):
|
||||
"""
|
||||
Base Optimizer.
|
||||
"""
|
||||
|
||||
def __init__(self, optim: Optimizer): # pylint: disable=W0231
|
||||
self.optim = optim
|
||||
|
||||
@property
|
||||
def param_groups(self):
|
||||
return self.optim.param_groups
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
return self.optim.defaults
|
||||
|
||||
def add_param_group(self, *args, **kwargs):
|
||||
return self.optim.add_param_group(*args, **kwargs)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
return self.optim.step(*args, **kwargs)
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
self.optim.zero_grad(*args, **kwargs)
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
self.optim.load_state_dict(*args, **kwargs)
|
||||
|
||||
def state_dict(self):
|
||||
return self.optim.state_dict()
|
||||
|
||||
def backward(self, loss):
|
||||
loss.backward()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
|
||||
|
||||
def clip_grad_norm(self):
|
||||
pass
|
||||
|
||||
|
||||
class HybridZeroOptimizer(BaseOptimizer):
|
||||
"""
|
||||
Hybrid Zero Optimizer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
cpu_offload=False,
|
||||
overlap_broadcast=False,
|
||||
grad_scal_cfg: Config = None,
|
||||
zero_cfg: Config = None,
|
||||
):
|
||||
# DynamicGradScaler related args
|
||||
initial_scale = grad_scal_cfg.fp16.initial_scale
|
||||
min_scale = grad_scal_cfg.fp16.min_scale
|
||||
growth_interval = grad_scal_cfg.fp16.growth_interval
|
||||
growth_factor = grad_scal_cfg.growth_factor
|
||||
backoff_factor = grad_scal_cfg.backoff_factor
|
||||
hysteresis = grad_scal_cfg.hysteresis
|
||||
max_scale = grad_scal_cfg.max_scale
|
||||
|
||||
# Zero related args
|
||||
overlap_communication = zero_cfg.zero_overlap_communication
|
||||
reduce_bucket_size = zero_cfg.reduce_bucket_size
|
||||
clip_grad_norm = zero_cfg.clip_grad_norm
|
||||
|
||||
super().__init__(optim=optimizer)
|
||||
|
||||
self._dtype = self.optim.param_groups[0]["params"][0].dtype
|
||||
self._cpu_offload = cpu_offload
|
||||
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
self._broadcast_parallel_mode = ParallelMode.ZERO1
|
||||
|
||||
# ParameterStore will manage the tensor buffers used for zero
|
||||
# it will not manage the tensors used by mixed precision training
|
||||
self._param_store = ParameterStore(ParallelMode.ZERO1)
|
||||
self._grad_store = GradientStore(ParallelMode.DATA)
|
||||
self._bucket_store = BucketStore(ParallelMode.DATA)
|
||||
|
||||
# fp16 and fp32 params for mixed precision training
|
||||
self._fp16_param_groups = dict()
|
||||
self._fp32_flat_param_groups_of_current_rank = dict()
|
||||
|
||||
# communication params
|
||||
self._overlap_communication = overlap_communication
|
||||
self._reduce_bucket_size = reduce_bucket_size
|
||||
|
||||
# gradient scaler
|
||||
self.grad_scaler = DynamicGradScaler(
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale,
|
||||
)
|
||||
self._found_overflow = torch.cuda.FloatTensor([0], device=get_current_device())
|
||||
|
||||
# gradient clipping
|
||||
self._clip_grad_norm = clip_grad_norm
|
||||
|
||||
# need to record the rank in which parameter groups are not assigned parameters.
|
||||
self.param_group_has_params = []
|
||||
self.param_group_no_params_ranks = []
|
||||
self.padding_grad = torch.zeros([32], dtype=self._dtype, device=get_current_device())
|
||||
self.padding_tensor = torch.zeros([32], dtype=self._dtype, device=get_current_device())
|
||||
|
||||
self.rank_unique_id = (
|
||||
f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
|
||||
+ f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
|
||||
+ f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_"
|
||||
+ f"zo-{self._zero_local_rank}.pt"
|
||||
)
|
||||
self.params_per_rank_id_dict = []
|
||||
self.overlap_broadcast = overlap_broadcast
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
# partition these param groups for data parallel training
|
||||
# and add buffers to parameter store for future access
|
||||
for group_id, param_group in enumerate(self.optim.param_groups):
|
||||
group_params = param_group["params"]
|
||||
|
||||
# add the fp16 params to fp16_param_groups for bookkeeping
|
||||
self._fp16_param_groups[group_id] = group_params
|
||||
|
||||
# assign parameters to ranks the params in the list are sorted
|
||||
params_per_rank, no_params_ranks = self._partition_param_list(group_params)
|
||||
self.param_group_no_params_ranks.append(no_params_ranks)
|
||||
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
|
||||
|
||||
# store the mapping between param to rank each param should belong to only one rank
|
||||
for rank, params in enumerate(params_per_rank):
|
||||
# check whether any rank is not assigned params.
|
||||
if len(params) != 0:
|
||||
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
|
||||
for param in params:
|
||||
self._param_store.set_param_to_rank(param, rank)
|
||||
|
||||
# move to cpu to make room to create the flat tensor
|
||||
for param in group_params:
|
||||
param.data = param.data.cpu()
|
||||
|
||||
# flatten the reordered tensors
|
||||
for rank in range(self._zero_world_size):
|
||||
# No flat fp16 buffer is allocated if the process has no parameters.
|
||||
if rank not in self.param_group_no_params_ranks[group_id]:
|
||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
||||
with torch.no_grad():
|
||||
flat_tensor = flatten(tensor_list)
|
||||
flat_tensor = flat_tensor.data.cuda()
|
||||
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
|
||||
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
|
||||
|
||||
# create a copy of fp32 weights of the parameters for which this rank is responsible
|
||||
# No flat fp32 buffer is allocated if the process has no parameters.
|
||||
if self.param_group_has_params[group_id]:
|
||||
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
self._zero_local_rank, group_id
|
||||
)
|
||||
fp32_flat_current_rank = fp16_flat_current_rank.float()
|
||||
device = "cpu" if self._cpu_offload else get_current_device()
|
||||
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
|
||||
fp32_flat_current_rank.requires_grad = True
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
|
||||
|
||||
# need to replace the params in the `params` field in the optimizer
|
||||
# so that when the optimizer calls step(), it only updates the tensors
|
||||
# managed by this data parallel rank
|
||||
param_group["params"] = [fp32_flat_current_rank]
|
||||
|
||||
# set reduction state
|
||||
for param in self._fp16_param_groups[group_id]:
|
||||
self._param_store.set_param_reduction_state(param, False)
|
||||
|
||||
assert len(self._fp16_param_groups) != 0
|
||||
|
||||
# If a rank is not assigned any arguments, 'has_params' is False.
|
||||
self.has_params = sum(self.param_group_has_params) != 0
|
||||
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
|
||||
self.skip_grad_reduce = False
|
||||
|
||||
# intialize communication stream for
|
||||
# communication-compuation overlapping
|
||||
if self._overlap_communication:
|
||||
self._comm_stream = torch.cuda.Stream()
|
||||
|
||||
# reduction hook is only used if overlapping communication
|
||||
# if it is stage 1 without overlapping, no hook will be attached
|
||||
if self._overlap_communication:
|
||||
self._attach_reduction_hook()
|
||||
|
||||
@property
|
||||
def zero_local_rank(self):
|
||||
return self._zero_local_rank
|
||||
|
||||
@property
|
||||
def zero_world_size(self):
|
||||
return self._zero_world_size
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale
|
||||
|
||||
@property
|
||||
def num_param_groups(self):
|
||||
return len(self._fp16_param_groups)
|
||||
|
||||
def _partition_param_list(self, param_list):
|
||||
no_params_ranks = []
|
||||
params_per_rank = [[] for _ in range(self._zero_world_size)]
|
||||
numel_per_rank = [0 for _ in range(self._zero_world_size)]
|
||||
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
|
||||
|
||||
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
||||
for i, param in enumerate(sorted_params):
|
||||
global_id = str(i)
|
||||
for j in range(len(param.size())):
|
||||
global_id = "_".join([global_id, str(param.size()[j])])
|
||||
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
params_per_rank[rank_to_go].append(param)
|
||||
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
|
||||
numel_per_rank[rank_to_go] += param.numel()
|
||||
|
||||
# check whether any rank is not assigned to parameters.
|
||||
for rank, params in enumerate(params_per_rank):
|
||||
if len(params) == 0:
|
||||
no_params_ranks.append(rank)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}")
|
||||
|
||||
return params_per_rank, set(no_params_ranks)
|
||||
|
||||
def _attach_reduction_hook(self):
|
||||
# we iterate over the fp16 params
|
||||
# on each param, we register a hook to its AccumulateGrad object
|
||||
for group_id in range(self.num_param_groups):
|
||||
param_group = self._fp16_param_groups[group_id]
|
||||
for param in param_group:
|
||||
if param.requires_grad:
|
||||
reduce_rank = None
|
||||
|
||||
def _define_and_attach(param, reduce_rank=None):
|
||||
# get the AccumulateGrad object of the param itself
|
||||
# If these objects are not kept, reduction hooks may not be attached successfully.
|
||||
accum_grad_obj = get_grad_accumulate_object(param)
|
||||
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
|
||||
|
||||
reduction_func = partial(
|
||||
self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
|
||||
)
|
||||
|
||||
# define hook
|
||||
# NOT IMPORTANT BUT GOOD TO KNOW:
|
||||
# args here is not grad, but allow_unreacable and accumulate_grad
|
||||
def reduce_grad_hook(*args): # pylint: disable=W0613
|
||||
if self.skip_grad_reduce is False:
|
||||
reduction_func()
|
||||
|
||||
accum_grad_obj.register_hook(reduce_grad_hook)
|
||||
|
||||
_define_and_attach(param, reduce_rank)
|
||||
|
||||
def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
|
||||
param_size = param.numel()
|
||||
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._reduce_grads_stored_in_bucket(reduce_rank)
|
||||
|
||||
# the param must not be reduced to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = (
|
||||
f"Parameter of size ({param.size()}) has already been reduced, "
|
||||
+ "duplicate reduction will lead to arithmetic incorrectness"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# the param must have grad for reduction
|
||||
assert param.grad is not None, f"Parameter of size ({param.size()}) has None grad, cannot be reduced"
|
||||
|
||||
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
||||
self._bucket_store.add_grad(param.grad, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
|
||||
|
||||
def _reduce_grads_stored_in_bucket(self, reduce_rank=None):
|
||||
# reduce grads
|
||||
self._reduce_grads_by_rank(
|
||||
reduce_rank=reduce_rank,
|
||||
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
|
||||
)
|
||||
|
||||
# use communication stream if overlapping
|
||||
# communication with computation
|
||||
if self._overlap_communication:
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
|
||||
|
||||
for param in params_in_bucket:
|
||||
# the is_param_reduced flag should be False showing that
|
||||
# this param is not reduced before calling self._reduce_grads_by_rank
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
|
||||
if is_param_reduced:
|
||||
msg = (
|
||||
f"Parameter of size ({param.size()}) has been reduced, "
|
||||
+ "duplicate reduction will lead to arithmetic incorrectness"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# update the flag
|
||||
self._param_store.set_param_reduction_state(param, True)
|
||||
|
||||
self._bucket_store.reset_by_rank(reduce_rank)
|
||||
|
||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
flat = bucket.flatten()
|
||||
reduced_flat = reduce_tensor(
|
||||
tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
|
||||
)
|
||||
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._zero_local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
|
||||
def _has_inf_or_nan(self, tensor):
|
||||
try:
|
||||
tensor_mean = float(tensor.mean())
|
||||
except RuntimeError as instance:
|
||||
# We want to check if inst is actually an overflow exception.
|
||||
# RuntimeError could come from a different error.
|
||||
# If so, we still want the exception to propagate.
|
||||
if "value cannot be converted" not in instance.args[0]:
|
||||
raise
|
||||
return True
|
||||
else:
|
||||
if tensor_mean == float("inf") or tensor_mean == -float("inf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _sync_grad(self):
|
||||
# update param already reduced flag
|
||||
reduction_states = self._param_store.get_param_reduction_states()
|
||||
for tensor, _ in reduction_states.items():
|
||||
reduction_states[tensor] = False
|
||||
|
||||
# accumulate gradient
|
||||
avg_gradients = self._grad_store._averaged_gradients
|
||||
for group_id in range(self.num_param_groups):
|
||||
# the following operations are performed only on the rank to which parameters are assigned.
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
|
||||
|
||||
if group_id not in avg_gradients:
|
||||
avg_gradients[group_id] = []
|
||||
|
||||
param_idx = 0
|
||||
for param in param_group:
|
||||
if param.grad is not None:
|
||||
if len(avg_gradients[group_id]) == param_idx:
|
||||
avg_gradients[group_id].append(param.grad)
|
||||
else:
|
||||
avg_gradients[group_id][param_idx].add_(param.grad)
|
||||
param_idx += 1
|
||||
|
||||
# the gradients needed are stored in the avg_gradients buffer
|
||||
# thus, can clear this
|
||||
self.zero_grad()
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""
|
||||
Set parameter gradients to zero. If set_to_none = True, gradient
|
||||
will be set to None to save memory.
|
||||
|
||||
:param set_to_none: Whether set the gradient to None. Default value is True.
|
||||
:type set_to_none: bool
|
||||
"""
|
||||
for _, param_group in self._fp16_param_groups.items():
|
||||
for param in param_group:
|
||||
if set_to_none:
|
||||
param.grad = None
|
||||
elif param.grad is not None:
|
||||
param.grad.detach()
|
||||
param.grad.zero_()
|
||||
else:
|
||||
pass
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
loss = self.loss_scale * loss
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
|
||||
# Gradients may not be fully synchronized here.
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Args:
|
||||
closure (Callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
Returns:
|
||||
Union[bool, float]: Whether the gradient is success updated, and the gradient.
|
||||
"""
|
||||
assert closure is None, "closure is not supported by step()"
|
||||
|
||||
timer("sync_grad").start()
|
||||
# if not overlapping communication (no reduction hook is attached)
|
||||
# we need to manually reduce these gradients
|
||||
if not self._overlap_communication:
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
for param in self._fp16_param_groups[group_id]:
|
||||
if param.grad is not None:
|
||||
self._store_and_try_reduce_grads_by_bucket(param)
|
||||
|
||||
# we need to reduce the gradients left in the communication bucket
|
||||
self._reduce_grads_stored_in_bucket()
|
||||
|
||||
# clear reduced grads
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
|
||||
self._sync_grad()
|
||||
timer("sync_grad").stop()
|
||||
|
||||
return self._step(closure=closure)
|
||||
|
||||
def _step(self, closure=None):
|
||||
assert closure is None, "closure is not supported by step()"
|
||||
|
||||
# check for overflow
|
||||
found_inf = self._check_overflow()
|
||||
# Because you may encounter inf when computing norm
|
||||
timer("cal_norm").start()
|
||||
norm_groups = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
# compute norm
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
parameters = self._param_store.get_fp16_params_by_rank_group(
|
||||
group_id=group_id, rank=self._zero_local_rank
|
||||
)
|
||||
else:
|
||||
# in order to prevent collection communication from hanging,
|
||||
# we need to involve rank that are not assigned parameters in compute_norm(),
|
||||
# so we give them a fp16 vector of 0 values.
|
||||
gradients = [self.padding_grad]
|
||||
parameters = [self.padding_tensor]
|
||||
|
||||
if self._clip_grad_norm > 0:
|
||||
# this norm is before scaling, it will be very large
|
||||
norm_group = compute_norm(
|
||||
gradients=gradients,
|
||||
parameters=parameters,
|
||||
)
|
||||
if norm_group == -1:
|
||||
timer("cal_norm").stop()
|
||||
found_inf = True
|
||||
break
|
||||
norm_groups.append(norm_group)
|
||||
|
||||
loss_scale = float(self.loss_scale.item()) # backup
|
||||
self.grad_scaler.update(found_inf)
|
||||
# update loss scale if overflow occurs
|
||||
if found_inf:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("Overflow occurs, please check it.")
|
||||
self._grad_store._averaged_gradients = dict()
|
||||
self.zero_grad()
|
||||
return False, None
|
||||
|
||||
# copy the grad of fp16 param to fp32 param
|
||||
single_grad_partition_groups = []
|
||||
global_norm = 0
|
||||
for group_id in range(self.num_param_groups):
|
||||
# compute norm
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if not self.param_group_has_params[group_id]:
|
||||
continue
|
||||
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
|
||||
# create flat gradient for the flat fp32 params
|
||||
fp16_avg_grads = gradients
|
||||
flat_fp16_avg_grads = flatten(fp16_avg_grads)
|
||||
|
||||
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
|
||||
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
|
||||
|
||||
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
|
||||
assert (
|
||||
param_shape == flat_fp32_avg_grads.shape
|
||||
), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
|
||||
|
||||
single_grad_partition_groups.append(flat_fp32_avg_grads)
|
||||
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
|
||||
self._grad_store._averaged_gradients[group_id] = []
|
||||
self._grad_store._averaged_gradients[group_id] = []
|
||||
|
||||
# unscale and clip grads
|
||||
# get the global norm
|
||||
if self._clip_grad_norm > 0:
|
||||
global_norm = sum(norm_groups) ** 0.5
|
||||
|
||||
# the following operations are performed only on the rank to which parameters are assigned.
|
||||
if len(single_grad_partition_groups) != 0:
|
||||
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
|
||||
|
||||
timer("cal_norm").stop()
|
||||
# update the parameters
|
||||
timer("step").start()
|
||||
|
||||
# For those ranks that are not assigned parameters, we just wait for other ranks
|
||||
# to send them updated their own parameters.
|
||||
if self.has_params:
|
||||
self.optim.step()
|
||||
# release the fp32 grad
|
||||
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
|
||||
# update fp16 partition updated by the current rank
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
if self.param_group_has_params[group_id]:
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=self._zero_local_rank, group_id=group_id
|
||||
)
|
||||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
||||
# TODO: support broadcast overlap
|
||||
self.broadcast_params(overlap=False)
|
||||
|
||||
timer("step").stop()
|
||||
# update gradients may not be needed here, because the sync_params function is used in initialization,
|
||||
# so synchronization is maintained
|
||||
return True, global_norm / loss_scale
|
||||
|
||||
def broadcast_params(self, overlap=False):
|
||||
handles = []
|
||||
|
||||
for group_id in range(self.num_param_groups):
|
||||
for rank in range(self._zero_world_size):
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if rank not in self.param_group_no_params_ranks[group_id]:
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
|
||||
# assert grank == rank, f"{grank} == {rank}"
|
||||
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
|
||||
handle = dist.broadcast(
|
||||
fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
|
||||
)
|
||||
handles.append(handle)
|
||||
|
||||
if not overlap:
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
else:
|
||||
return handles
|
||||
|
||||
##################
|
||||
# FP16 Utilities #
|
||||
##################
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
self._found_overflow.fill_(0.0)
|
||||
|
||||
# check for overflow
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
|
||||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||
self._found_overflow.fill_(1.0)
|
||||
break
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
|
||||
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm, loss_scale):
|
||||
# compute combined scale factor for this group
|
||||
combined_scale = loss_scale
|
||||
|
||||
if self._clip_grad_norm > 0.0:
|
||||
# norm is in fact norm*scale
|
||||
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
|
||||
if clip > 1.0:
|
||||
combined_scale = clip * loss_scale
|
||||
|
||||
for grad in grad_groups_flat:
|
||||
grad.data.mul_(1.0 / combined_scale)
|
||||
|
||||
def clip_grad_norm(self, model, max_norm):
|
||||
# will conduct in the step()
|
||||
pass
|
||||
|
||||
def state_dict(self):
|
||||
states = {}
|
||||
grad_scaler = self.grad_scaler.state_dict()
|
||||
states["grad_scaler"] = grad_scaler
|
||||
optim_states = self.optim.state_dict()
|
||||
states["base_optim_states"] = optim_states
|
||||
|
||||
flat_fp32_weights = {}
|
||||
for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
assert param.grad is None
|
||||
flat_fp32_weights[group_id] = param
|
||||
states["flat_fp32_weights"] = flat_fp32_weights
|
||||
states["zero_devide_optim_plan"] = self.params_per_rank_id_dict
|
||||
|
||||
return states
|
||||
|
||||
def load_state_dict(self, states):
|
||||
# TODO: Need to take into account the change in the number of DP.
|
||||
assert "grad_scaler" in states, "Not found grad_scaler state!"
|
||||
grad_scaler = states["grad_scaler"]
|
||||
self.grad_scaler.load_state_dict(grad_scaler)
|
||||
optim_states = states["base_optim_states"]
|
||||
self.optim.load_state_dict(optim_states)
|
||||
|
||||
# load fp32 model weight.
|
||||
flat_fp32_weights = states["flat_fp32_weights"]
|
||||
assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
|
||||
for group_id, param in flat_fp32_weights.items():
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
assert (
|
||||
self_param.shape == param.shape
|
||||
), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
|
||||
self_param.data.copy_(param.data)
|
||||
|
||||
# Load the fp16 model weights.
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=self._zero_local_rank, group_id=group_id
|
||||
)
|
||||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
||||
if "zero_devide_optim_plan" in states:
|
||||
self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
|
||||
|
||||
|
||||
def compute_norm(gradients, parameters, norm_type=2):
|
||||
"""Get the norm
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): The gradient value.
|
||||
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters, need total_norm**(1/norm) before using.
|
||||
"""
|
||||
|
||||
enable_cuda_kernels = gradients[0].device.type == "cuda"
|
||||
# Norm parameters.
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
for g, p in zip(gradients, parameters):
|
||||
# TODO: consider the pipeline shared parameter
|
||||
if (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) == 0
|
||||
): # if shared between different pipe, only count o
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) != 0
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.TENSOR)
|
||||
and not is_model_parallel_parameter(p)
|
||||
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||
): # if not used in each chunk, such as layernorm
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif is_model_parallel_parameter(p):
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError("Should not arrive here")
|
||||
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
|
||||
|
||||
# If norm is type of float, then we convert them into torch.Tensor.
|
||||
tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
|
||||
# If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
|
||||
if not enable_cuda_kernels:
|
||||
tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)
|
||||
|
||||
total_norm = tensor_parallel_norm
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL):
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
|
||||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
if total_norm == float("inf") or total_norm == -float("inf"):
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
|
@ -0,0 +1,284 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
class BaseStore:
|
||||
"""
|
||||
Base Store
|
||||
"""
|
||||
|
||||
def __init__(self, dp_parallel_mode=ParallelMode.DATA):
|
||||
self._world_size = gpc.get_world_size(dp_parallel_mode)
|
||||
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
@property
|
||||
def local_rank(self):
|
||||
return self._local_rank
|
||||
|
||||
|
||||
class BucketStore(BaseStore):
|
||||
"""
|
||||
Bucket Store
|
||||
"""
|
||||
|
||||
def __init__(self, dp_parallel_mode):
|
||||
super().__init__(dp_parallel_mode)
|
||||
self._grads = dict()
|
||||
self._params = dict()
|
||||
self._num_elements_in_bucket = dict()
|
||||
|
||||
self.reset()
|
||||
|
||||
def num_elements_in_bucket(self, reduce_rank: int = None):
|
||||
return self._num_elements_in_bucket[reduce_rank]
|
||||
|
||||
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
|
||||
self._num_elements_in_bucket[reduce_rank] += num_elements
|
||||
|
||||
def add_grad(self, tensor, reduce_rank: int = None):
|
||||
self._grads[reduce_rank].append(tensor)
|
||||
|
||||
def add_param(self, tensor, reduce_rank: int = None):
|
||||
self._params[reduce_rank].append(tensor)
|
||||
|
||||
def reset(self):
|
||||
keys = [None] + list(range(self._world_size))
|
||||
self._grads = {rank: [] for rank in keys}
|
||||
self._params = {rank: [] for rank in keys}
|
||||
self._num_elements_in_bucket = {rank: 0 for rank in keys}
|
||||
|
||||
def reset_by_rank(self, reduce_rank=None):
|
||||
self._grads[reduce_rank] = []
|
||||
self._params[reduce_rank] = []
|
||||
self._num_elements_in_bucket[reduce_rank] = 0
|
||||
|
||||
def get_grad(self, reduce_rank: int = None):
|
||||
return self._grads[reduce_rank]
|
||||
|
||||
def get_param(self, reduce_rank: int = None):
|
||||
return self._params[reduce_rank]
|
||||
|
||||
|
||||
class GradientStore(BaseStore):
|
||||
"""
|
||||
Gradient Store
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
# bookkeeping data structures
|
||||
self._averaged_gradients = dict()
|
||||
|
||||
# for backward reduction hooks
|
||||
self._grad_acc_objs = []
|
||||
|
||||
def add_accumulate_grad_object(self, obj):
|
||||
"""
|
||||
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
|
||||
be attached successfully.
|
||||
|
||||
:param obj: An object of :class:`AccumulateGrad` class
|
||||
:type obj: :class:`AccumulateGrad`
|
||||
"""
|
||||
|
||||
self._grad_acc_objs.append(obj)
|
||||
|
||||
def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
|
||||
"""
|
||||
Return average gradients of a parameter group
|
||||
|
||||
:param group_id: The index of parameter group
|
||||
:type group_id: int
|
||||
|
||||
:return: Return the list of averaged gradients of a parameter group. Each element is a gradient,
|
||||
not a parameter.
|
||||
:rtype: List[torch.Tensor]
|
||||
"""
|
||||
|
||||
return self._averaged_gradients[group_id]
|
||||
|
||||
def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
|
||||
"""
|
||||
Append an average gradient to the list of averaged gradients of a parameter group
|
||||
|
||||
:param group_id: The index of a parameter group
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type group_id: int
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
"""
|
||||
|
||||
if group_id in self._averaged_gradients:
|
||||
self._averaged_gradients[group_id].append(tensor)
|
||||
else:
|
||||
self._averaged_gradients[group_id] = [tensor]
|
||||
|
||||
def reset_average_gradients_by_group(self, group_id: int) -> None:
|
||||
"""
|
||||
Reset the bookkeeping data structure for averaged gradients to an empty list
|
||||
|
||||
:param group_id: The index of a parameter group
|
||||
:type group_id: int
|
||||
"""
|
||||
|
||||
self._averaged_gradients[group_id] = []
|
||||
|
||||
|
||||
class ParameterStore(BaseStore):
|
||||
"""
|
||||
Parameter Store
|
||||
"""
|
||||
|
||||
def __init__(self, dp_paralle_mode):
|
||||
super().__init__(dp_paralle_mode)
|
||||
# param partitioning data structures
|
||||
self._fp16_param_to_rank = dict()
|
||||
self._rank_groupid_to_fp16_param_list = dict()
|
||||
self._rank_group_id_to_flat_fp16_param = dict()
|
||||
|
||||
# param reduction data structures
|
||||
self._is_param_reduced = dict()
|
||||
self._reduced_param = []
|
||||
|
||||
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
|
||||
"""
|
||||
Set the mapping between parameter to rank, each parameter should be owned by a rank.
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
:param rank: The rank of which the process is responsible for updating the parameter
|
||||
:type rank: int
|
||||
"""
|
||||
|
||||
self._fp16_param_to_rank[tensor] = rank
|
||||
|
||||
def get_param_rank(self, tensor: Tensor) -> int:
|
||||
"""
|
||||
Gives the rank which the parameter belongs to
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
"""
|
||||
return self._fp16_param_to_rank[tensor]
|
||||
|
||||
def belongs_to_current_rank(self, tensor) -> bool:
|
||||
"""
|
||||
Check whether a parameter is supposed to be updated by the process of the current rank
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
:return: True if the parameter should be updated by the current rank. Otherwise false.
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
tensor_rank = self._fp16_param_to_rank[tensor]
|
||||
return tensor_rank == self._local_rank
|
||||
|
||||
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
|
||||
if rank not in self._rank_groupid_to_fp16_param_list:
|
||||
self._rank_groupid_to_fp16_param_list[rank] = dict()
|
||||
|
||||
if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
|
||||
self._rank_groupid_to_fp16_param_list[rank][group_id] = []
|
||||
|
||||
self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)
|
||||
|
||||
def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
|
||||
return self._rank_groupid_to_fp16_param_list[rank][group_id]
|
||||
|
||||
def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
|
||||
if rank not in self._rank_group_id_to_flat_fp16_param:
|
||||
self._rank_group_id_to_flat_fp16_param[rank] = dict()
|
||||
|
||||
self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
|
||||
|
||||
def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
|
||||
return self._rank_group_id_to_flat_fp16_param[rank][group_id]
|
||||
|
||||
def is_param_reduced(self, tensor):
|
||||
return self._is_param_reduced[tensor]
|
||||
|
||||
def set_param_reduction_state(self, tensor, state):
|
||||
self._is_param_reduced[tensor] = state
|
||||
|
||||
def get_param_reduction_states(self):
|
||||
return self._is_param_reduced
|
||||
|
||||
def reset_previous_reduced_params(self):
|
||||
self._reduced_param = []
|
||||
|
||||
def add_previous_reduced_param(self, tensor):
|
||||
self._reduced_param.append(tensor)
|
||||
|
||||
def clear_grads_of_previous_reduced_params(self):
|
||||
if len(self._reduced_param) > 0:
|
||||
for param in self._reduced_param:
|
||||
param.grad = None
|
||||
self.reset_previous_reduced_params()
|
||||
|
||||
|
||||
class TensorBucket:
|
||||
"""
|
||||
Tensor Bucket
|
||||
"""
|
||||
|
||||
def __init__(self, size):
|
||||
self._max_size = size
|
||||
self._current_size = 0
|
||||
self._bucket = []
|
||||
|
||||
@property
|
||||
def max_size(self):
|
||||
return self._max_size
|
||||
|
||||
@property
|
||||
def current_size(self):
|
||||
return self._current_size
|
||||
|
||||
def is_full_or_oversized(self):
|
||||
return self._current_size >= self._max_size
|
||||
|
||||
def is_empty(self):
|
||||
return len(self._bucket) == 0
|
||||
|
||||
def add_to_bucket(self, tensor, allow_oversize=False):
|
||||
tensor_size = tensor.numel()
|
||||
|
||||
if not allow_oversize and self.will_exceed_max_size(tensor_size):
|
||||
msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self._bucket.append(tensor)
|
||||
self._current_size += tensor_size
|
||||
|
||||
def will_exceed_max_size(self, tensor_size):
|
||||
expected_size = self._current_size + tensor_size
|
||||
return expected_size > self._max_size
|
||||
|
||||
def get_bucket(self):
|
||||
return self._bucket
|
||||
|
||||
def empty(self):
|
||||
self._bucket = []
|
||||
self._size = 0
|
||||
|
||||
def flatten(self):
|
||||
return _flatten_dense_tensors(self._bucket)
|
||||
|
||||
def unflatten_and_copy(self, flat_tensor):
|
||||
unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
|
||||
for old, new in zip(self._bucket, unflattened_tensor_list):
|
||||
old.copy_(new)
|
|
@ -0,0 +1,315 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def flatten(input_):
|
||||
return _flatten_dense_tensors(input_)
|
||||
|
||||
|
||||
def unflatten(flat, tensors):
|
||||
return _unflatten_dense_tensors(flat, tensors)
|
||||
|
||||
|
||||
def get_grad_accumulate_object(tensor):
|
||||
"""
|
||||
Return the AccumulateGrad of the input tensor
|
||||
"""
|
||||
|
||||
# grad_fn reference:
|
||||
# https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
|
||||
# expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
|
||||
#
|
||||
# `next_functions` will return the backward graph where
|
||||
# the first element is the AccumulateGrad of the leaf nodes.
|
||||
# we want to get the AccumulateGrad of the input tensor instead of the leaf
|
||||
# node in the whole computation graph.
|
||||
# Therefore, we call expand_as to create a dummy graph
|
||||
# where tensor_tmp and tensor indeed point to the same object.
|
||||
# You can check this by print(tensor.data_ptr() == tensor_tmp.data_ptr())
|
||||
tensor_tmp = tensor.expand_as(tensor)
|
||||
grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0]
|
||||
return grad_acc_obj
|
||||
|
||||
|
||||
def split_half_float_double(tensor_list):
|
||||
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
|
||||
buckets = []
|
||||
for _, dtype in enumerate(dtypes):
|
||||
bucket = [t for t in tensor_list if t.type() == dtype]
|
||||
if bucket:
|
||||
buckets.append(bucket)
|
||||
return buckets
|
||||
|
||||
|
||||
def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
|
||||
"""
|
||||
Reduce the tensor in the data parallel process group
|
||||
|
||||
:param tensor: A tensor object to reduce/all-reduce
|
||||
:param dtype: The data type used in communication
|
||||
:param dst_rank: The source rank for reduce. If dst_rank is None,
|
||||
:param parallel_mode: Communication parallel mode
|
||||
all-reduce will be used instead of reduce. Default is None.
|
||||
|
||||
:type tensor: torch.Tensor
|
||||
:type dtype: torch.dtype, optional
|
||||
:type dst_rank: int, optional
|
||||
:type parallel_mode: ParallelMode, optional
|
||||
"""
|
||||
# use the original dtype
|
||||
if dtype is None:
|
||||
dtype = tensor.dtype
|
||||
|
||||
# cast the data to specified dtype for reduce/all-reduce
|
||||
if tensor.dtype != dtype:
|
||||
tensor_to_reduce = tensor.to(dtype)
|
||||
else:
|
||||
tensor_to_reduce = tensor
|
||||
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
group = gpc.get_group(parallel_mode)
|
||||
tensor_to_reduce.div_(world_size)
|
||||
|
||||
# if rank is None, all reduce will be used
|
||||
# else, reduce is used
|
||||
use_all_reduce = dst_rank is None
|
||||
|
||||
if use_all_reduce:
|
||||
dist.all_reduce(tensor_to_reduce, group=group)
|
||||
else:
|
||||
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
|
||||
global_rank = ranks_in_group[dst_rank]
|
||||
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
|
||||
|
||||
# recover the original dtype
|
||||
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
|
||||
local_rank = gpc.get_local_rank(parallel_mode)
|
||||
if use_all_reduce or dst_rank == local_rank:
|
||||
tensor.copy_(tensor_to_reduce)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def has_inf_or_nan(tensor):
|
||||
try:
|
||||
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
|
||||
# Pytorch's .sum() creates a one-element tensor of the same type as tensor
|
||||
# (which is true for some recent version of pytorch).
|
||||
tensor_sum = float(tensor.float().sum())
|
||||
# More efficient version that can be used if .sum() returns a Python scalar
|
||||
# tensor_sum = float(tensor.sum())
|
||||
except RuntimeError as instance:
|
||||
# We want to check if inst is actually an overflow exception.
|
||||
# RuntimeError could come from a different error.
|
||||
# If so, we still want the exception to propagate.
|
||||
if "value cannot be converted" not in instance.args[0]:
|
||||
raise
|
||||
return True
|
||||
else:
|
||||
if tensor_sum == float("inf") or tensor_sum == -float("inf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def release_param_grad(tensor_list):
|
||||
for tensor in tensor_list:
|
||||
tensor.grad = None
|
||||
|
||||
|
||||
def sync_param(flat_tensor, tensor_list):
|
||||
"""
|
||||
Synchronize the flattened tensor and unflattened tensor list. When
|
||||
a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`,
|
||||
a new tensor is created. Thus, the flat tensor and original tensor list do not
|
||||
share the same memory space. This function will update the tensor list so that
|
||||
they point to the same value.
|
||||
|
||||
:param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor lsit
|
||||
:param tensor_list: A list of tensors corresponding to the flattened tensor
|
||||
:type flat_tensor: torch.Tensor
|
||||
:type tensor_list: List[torch.Tensor]
|
||||
"""
|
||||
updated_params = unflatten(flat_tensor, tensor_list)
|
||||
|
||||
# update the tensor data
|
||||
for p, q in zip(tensor_list, updated_params):
|
||||
p.data = q.data
|
||||
|
||||
|
||||
class BaseGradScaler(ABC):
|
||||
"""A base class for the gradient scaler.
|
||||
|
||||
Args:
|
||||
initial_scale (float): the initial loss scale
|
||||
"""
|
||||
|
||||
def __init__(self, initial_scale: float):
|
||||
assert initial_scale > 0
|
||||
self._scale = torch.cuda.FloatTensor([initial_scale])
|
||||
|
||||
@property
|
||||
def scale(self) -> Tensor:
|
||||
"""Returns the loss scale."""
|
||||
|
||||
return self._scale
|
||||
|
||||
@property
|
||||
def inv_scale(self) -> Tensor:
|
||||
"""Returns the inverse of the loss scale."""
|
||||
|
||||
return self._scale.double().reciprocal().float()
|
||||
|
||||
def state_dict(self) -> Dict:
|
||||
"""Returns the states of the gradient scaler as a dict object."""
|
||||
|
||||
state_dict = dict()
|
||||
state_dict["scale"] = self.scale
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: Dict) -> None:
|
||||
"""Load the states of the gradient scaler from a dict object.
|
||||
|
||||
Args:
|
||||
state_dict (dict): the states of the gradient scaler
|
||||
"""
|
||||
|
||||
self._scale = state_dict["scale"]
|
||||
|
||||
@abstractmethod
|
||||
def update(self, overflow: bool) -> None:
|
||||
"""Update the loss scale.
|
||||
|
||||
Args:
|
||||
overflow (bool): whether overflow occurs
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DynamicGradScaler(BaseGradScaler):
|
||||
"""A gradient scaler which uses dynamic loss scale
|
||||
|
||||
Args:
|
||||
initial_scale (float): the initial loss scale, defaults to 2**16
|
||||
growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2
|
||||
backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5
|
||||
growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000
|
||||
min_scale (float): the minimum loss scale, defaults to None
|
||||
max_scale (float): the maximum loss scale, defaults to None
|
||||
hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_scale: float = 2**16,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
min_scale: Optional[float] = None,
|
||||
max_scale: Optional[float] = None,
|
||||
hysteresis: int = 2,
|
||||
):
|
||||
super().__init__(initial_scale)
|
||||
if min_scale:
|
||||
self._min_scale = torch.cuda.FloatTensor([min_scale])
|
||||
else:
|
||||
self._min_scale = None
|
||||
|
||||
if max_scale:
|
||||
self._max_scale = torch.cuda.FloatTensor([max_scale])
|
||||
else:
|
||||
self._max_scale = None
|
||||
|
||||
self._growth_factor = growth_factor
|
||||
self._backoff_factor = backoff_factor
|
||||
self._growth_interval = growth_interval
|
||||
self._growth_step = 0
|
||||
self._hysteresis = hysteresis
|
||||
self._hysteresis_step = 0
|
||||
self._sanity_checks()
|
||||
|
||||
def _sanity_checks(self) -> None:
|
||||
"""Check if the arguments are correct."""
|
||||
|
||||
if self._min_scale:
|
||||
assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
|
||||
if self._max_scale:
|
||||
assert self._min_scale > 0, "The maximum gradient scale cannot be zero or negative"
|
||||
assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
|
||||
assert self._backoff_factor < 1 and self._backoff_factor > 0, "The backoff factor must be between 0 and 1"
|
||||
assert self._hysteresis >= 0, "The hysteresis cannot be negative"
|
||||
|
||||
def update(self, overflow: bool) -> None:
|
||||
"""Update the loss scale.
|
||||
|
||||
Args:
|
||||
overflow (bool): whether overflow occurs
|
||||
"""
|
||||
if overflow:
|
||||
self._hysteresis_step += 1
|
||||
self._growth_step = 0
|
||||
|
||||
if self._hysteresis_step >= self._hysteresis:
|
||||
self._backoff_scale()
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}")
|
||||
else:
|
||||
self._growth_step += 1
|
||||
if self._growth_step == self._growth_interval:
|
||||
self._growth_step = 0
|
||||
self._hysteresis_step = 0
|
||||
self._grow_scale()
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
f"No overflow for consecutive {self._growth_interval} steps, "
|
||||
f"the loss scale is adjusted to {self.scale.item()}",
|
||||
)
|
||||
|
||||
def _backoff_scale(self) -> None:
|
||||
"""Decrease the loss scale"""
|
||||
|
||||
self._scale = self._scale * self._backoff_factor
|
||||
if self._min_scale:
|
||||
self._scale = torch.max(self._scale, self._min_scale)
|
||||
|
||||
def _grow_scale(self) -> None:
|
||||
"""Increase the loss scale"""
|
||||
|
||||
self._scale = self._scale * self._growth_factor
|
||||
if self._max_scale:
|
||||
self._scale = torch.min(self._scale, self._max_scale)
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the states of the gradient scaler as a dict object."""
|
||||
|
||||
state_dict = dict()
|
||||
state_dict["_scale"] = self._scale.item()
|
||||
state_dict["_growth_step"] = self._growth_step
|
||||
state_dict["_hysteresis_step"] = self._hysteresis_step
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Load the states of the gradient scaler from a dict object.
|
||||
|
||||
Args:
|
||||
state_dict (dict): the states of the gradient scaler
|
||||
"""
|
||||
|
||||
self._scale = self._scale.fill_(state_dict["_scale"])
|
||||
self._growth_step = state_dict["_growth_step"]
|
||||
self._hysteresis_step = state_dict["_hysteresis_step"]
|
|
@ -0,0 +1,34 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
||||
assert (
|
||||
num_items % num_chunks == 0
|
||||
), "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
||||
|
||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||
partition_items = num_items // num_chunks
|
||||
for idx in range(num_chunks):
|
||||
base_idx = idx * partition_items
|
||||
chunk_size = partition_items // pipeline_parallel_size
|
||||
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
|
||||
if chunk_size == 0:
|
||||
raise ValueError("Some nodes in Pipeline have no requests")
|
||||
|
||||
for p in range(pipeline_parallel_size):
|
||||
st = base_idx
|
||||
base_idx += chunk_size + (p >= left)
|
||||
parts[p].append((st, base_idx))
|
||||
|
||||
indexes = []
|
||||
for _parts in parts:
|
||||
for s, e in _parts:
|
||||
indexes.extend(list(range(s, e)))
|
||||
assert len(indexes) == len(set(indexes)), indexes # should have no duplicates
|
||||
assert set(indexes) == set(list(range(num_items))), (indexes, num_items) # should have the same indexes as expected
|
||||
return parts
|
|
@ -0,0 +1,269 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
from torch.utils.checkpoint import check_backward_validity, detach_variable
|
||||
|
||||
from internlm.core.context.random import (
|
||||
get_current_mode,
|
||||
get_states,
|
||||
set_mode,
|
||||
set_seed_states,
|
||||
sync_states,
|
||||
)
|
||||
|
||||
from .common import get_current_device
|
||||
|
||||
|
||||
def copy_to_device(obj, device):
|
||||
if torch.is_tensor(obj):
|
||||
# Notice:
|
||||
# When in no_grad context, requires_gard is False after movement
|
||||
ret = obj.to(device).detach()
|
||||
ret.requires_grad = obj.requires_grad
|
||||
return ret
|
||||
elif isinstance(obj, list):
|
||||
return [copy_to_device(i, device) for i in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple([copy_to_device(v, device) for v in obj])
|
||||
elif isinstance(obj, dict):
|
||||
return {k: copy_to_device(v, device) for k, v in obj.items()}
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
"""
|
||||
Checkpoint Function
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, activation_offload=False, *args): # pylint: disable=W1113
|
||||
check_backward_validity(args)
|
||||
ctx.run_function = run_function
|
||||
ctx.activation_offload = activation_offload
|
||||
ctx.device = get_current_device()
|
||||
|
||||
# preserve rng states
|
||||
ctx.fwd_cpu_rng_state = torch.get_rng_state()
|
||||
sync_states()
|
||||
ctx.fwd_seed_states = get_states(copy=True)
|
||||
ctx.fwd_current_mode = get_current_mode()
|
||||
|
||||
if hasattr(torch, "is_autocast_enabled"):
|
||||
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
|
||||
else:
|
||||
ctx.had_autocast_in_fwd = False
|
||||
|
||||
if activation_offload:
|
||||
inputs_cuda = copy_to_device(args, ctx.device)
|
||||
else:
|
||||
inputs_cuda = args
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = run_function(*inputs_cuda)
|
||||
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
||||
# to be filled out during the backward.
|
||||
ctx.inputs = []
|
||||
ctx.tensor_indices = []
|
||||
tensor_inputs = []
|
||||
for i, arg in enumerate(args):
|
||||
if torch.is_tensor(arg):
|
||||
if activation_offload:
|
||||
tensor_inputs.append(copy_to_device(arg, "cpu"))
|
||||
else:
|
||||
tensor_inputs.append(arg)
|
||||
ctx.tensor_indices.append(i)
|
||||
ctx.inputs.append(None)
|
||||
else:
|
||||
ctx.inputs.append(arg)
|
||||
|
||||
if activation_offload:
|
||||
ctx.tensor_inputs = tensor_inputs
|
||||
else:
|
||||
ctx.save_for_backward(*tensor_inputs)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError(
|
||||
"Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
|
||||
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
|
||||
)
|
||||
# Copy the list to avoid modifying original list.
|
||||
inputs = list(ctx.inputs)
|
||||
tensor_indices = ctx.tensor_indices
|
||||
|
||||
if ctx.activation_offload:
|
||||
tensors = ctx.tensor_inputs
|
||||
else:
|
||||
tensors = ctx.saved_tensors
|
||||
|
||||
# store the current states
|
||||
bwd_cpu_rng_state = torch.get_rng_state()
|
||||
sync_states()
|
||||
bwd_seed_states = get_states(copy=True)
|
||||
bwd_current_mode = get_current_mode()
|
||||
|
||||
# set the states to what it used to be
|
||||
torch.set_rng_state(ctx.fwd_cpu_rng_state)
|
||||
for parallel_mode, state in ctx.fwd_seed_states.items():
|
||||
set_seed_states(parallel_mode, state)
|
||||
set_mode(ctx.fwd_current_mode)
|
||||
if ctx.activation_offload:
|
||||
tensors = copy_to_device(tensors, ctx.device)
|
||||
|
||||
# Fill in inputs with appropriate saved tensors.
|
||||
for i, idx in enumerate(tensor_indices):
|
||||
inputs[idx] = tensors[i]
|
||||
detached_inputs = detach_variable(tuple(inputs))
|
||||
if ctx.had_autocast_in_fwd:
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
else:
|
||||
with torch.enable_grad():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
# recover the rng states
|
||||
torch.set_rng_state(bwd_cpu_rng_state)
|
||||
for parallel_mode, state in bwd_seed_states.items():
|
||||
set_seed_states(parallel_mode, state)
|
||||
set_mode(bwd_current_mode)
|
||||
|
||||
# run backward() with only tensor that requires grad
|
||||
outputs_with_grad = []
|
||||
args_with_grad = []
|
||||
for i in range(len(outputs)):
|
||||
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
|
||||
outputs_with_grad.append(outputs[i])
|
||||
args_with_grad.append(args[i])
|
||||
if len(outputs_with_grad) == 0:
|
||||
raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary")
|
||||
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
||||
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
|
||||
return (None, None) + grads
|
||||
|
||||
|
||||
def activation_checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
|
||||
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
|
||||
Args:
|
||||
function: Describe the forward pass function. It should know how to handle the input tuples.
|
||||
activation_offload: The variable to check whether we should offload activation to cpu
|
||||
args (list): Tuple containing the parameters of the function
|
||||
use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
|
||||
might be more flexibility for user to define there checkpoint function
|
||||
Returns:
|
||||
Output of running function with provided args.
|
||||
"""
|
||||
if use_reentrant:
|
||||
return CheckpointFunction.apply(function, activation_offload, *args)
|
||||
else:
|
||||
return _checkpoint_without_reentrant(
|
||||
function,
|
||||
activation_offload,
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def _checkpoint_without_reentrant(function, activation_offload=False, *args): # pylint: disable=W1113
|
||||
# store rng_state
|
||||
fwd_cpu_state = torch.get_rng_state()
|
||||
sync_states()
|
||||
fwd_seed_states = get_states(copy=True)
|
||||
fwd_current_mode = get_current_mode()
|
||||
|
||||
# check if use autocast
|
||||
if hasattr(torch, "is_autocast_enabled"):
|
||||
has_autocast_in_fwd = torch.is_autocast_enabled()
|
||||
else:
|
||||
has_autocast_in_fwd = False
|
||||
|
||||
# using WeakKeyDictionary to store all the activation the first time we call unpack
|
||||
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
||||
weak_holder_list = []
|
||||
|
||||
# class for weakref.ref
|
||||
class Holder:
|
||||
pass
|
||||
|
||||
# return a Holder object for later unpack process
|
||||
def pack():
|
||||
res = Holder()
|
||||
weak_holder_list.append(weakref.ref(res))
|
||||
return res
|
||||
|
||||
# unpack hook
|
||||
def unpack(x):
|
||||
unpack_counter = 0
|
||||
|
||||
# re-compute all the activation inside the function when we first call unpack
|
||||
if len(storage) == 0:
|
||||
|
||||
def inner_pack(inner):
|
||||
nonlocal unpack_counter
|
||||
unpack_counter += 1
|
||||
|
||||
# If the holder went out of scope, the SavedVariable is dead and so
|
||||
# the value will never be read from the storage. Skip filling it.
|
||||
if weak_holder_list[unpack_counter - 1]() is None:
|
||||
return
|
||||
|
||||
# Use detach here to ensure we don't keep the temporary autograd
|
||||
# graph created during the second forward
|
||||
storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
|
||||
return
|
||||
|
||||
def inner_unpack(packed):
|
||||
raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
|
||||
|
||||
# restore rng state
|
||||
torch.set_rng_state(fwd_cpu_state)
|
||||
for parallel_mode, state in fwd_seed_states.items():
|
||||
set_seed_states(parallel_mode, state)
|
||||
set_mode(fwd_current_mode)
|
||||
|
||||
# reload arg into device if needed
|
||||
if activation_offload:
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg):
|
||||
arg = arg.to(device=device)
|
||||
|
||||
# rerun forward, the inner_pack will store all the activations in storage
|
||||
if has_autocast_in_fwd:
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
|
||||
inner_pack, inner_unpack
|
||||
):
|
||||
function(*args)
|
||||
else:
|
||||
with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
function(*args)
|
||||
|
||||
if x not in storage:
|
||||
raise RuntimeError(
|
||||
"Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
|
||||
" recomputation being triggered in between, this is not currently supported. Please"
|
||||
" open an issue with details on your use case so that we can prioritize adding this."
|
||||
)
|
||||
|
||||
return storage[x]
|
||||
|
||||
# get device if we need to offload the activation
|
||||
if activation_offload:
|
||||
device = get_current_device()
|
||||
|
||||
# run function with pack and unpack as saved_tensors_hooks
|
||||
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
|
||||
output = function(*args)
|
||||
|
||||
# offload activation if needed
|
||||
if activation_offload:
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg):
|
||||
arg = arg.to(device="cpu")
|
||||
|
||||
return output
|
|
@ -0,0 +1,248 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import bisect
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import internlm
|
||||
|
||||
CURRENT_TIME = None
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = internlm.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_master_node():
|
||||
import subprocess
|
||||
|
||||
if os.getenv("SLURM_JOB_ID") is None:
|
||||
raise RuntimeError("get_master_node can only used in Slurm launch!")
|
||||
result = subprocess.check_output('scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1', shell=True)
|
||||
result = result.decode("utf8").strip()
|
||||
return result
|
||||
|
||||
|
||||
def get_process_rank():
|
||||
proc_rank = -1
|
||||
if os.getenv("SLURM_PROCID") is not None:
|
||||
proc_rank = int(os.getenv("SLURM_PROCID"))
|
||||
elif os.getenv("RANK") is not None:
|
||||
# In k8s env, we use $RANK.
|
||||
proc_rank = int(os.getenv("RANK"))
|
||||
|
||||
# assert proc_rank != -1, "get_process_rank cant't get right process rank!"
|
||||
return proc_rank
|
||||
|
||||
|
||||
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
if torch.is_tensor(norm) and norm.device.type != "cuda":
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
return norm
|
||||
|
||||
|
||||
def _move_tensor(element):
|
||||
if not torch.is_tensor(element):
|
||||
# we expecte the data type if a list of dictionaries
|
||||
for item in element:
|
||||
if isinstance(item, dict):
|
||||
for key, value in item.items():
|
||||
assert not value.is_cuda, "elements are already on devices."
|
||||
item[key] = value.to(get_current_device()).detach()
|
||||
elif isinstance(item, list):
|
||||
for index, value in enumerate(item):
|
||||
assert not value.is_cuda, "elements are already on devices."
|
||||
item[index] = value.to(get_current_device()).detach()
|
||||
elif torch.is_tensor(item):
|
||||
if not item.is_cuda:
|
||||
item = item.to(get_current_device()).detach()
|
||||
else:
|
||||
assert torch.is_tensor(element), f"element should be of type tensor, but got {type(element)}"
|
||||
if not element.is_cuda:
|
||||
element = element.to(get_current_device()).detach()
|
||||
return element
|
||||
|
||||
|
||||
def move_to_device(data):
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.to(get_current_device())
|
||||
elif isinstance(data, (list, tuple)):
|
||||
data_to_return = []
|
||||
for element in data:
|
||||
if isinstance(element, dict):
|
||||
data_to_return.append(
|
||||
{
|
||||
k: (
|
||||
_move_tensor(v)
|
||||
if k != "inference_params"
|
||||
else v._replace(attention_mask=_move_tensor(v.attention_mask))
|
||||
)
|
||||
for k, v in element.items()
|
||||
}
|
||||
)
|
||||
else:
|
||||
data_to_return.append(_move_tensor(element))
|
||||
data = data_to_return
|
||||
elif isinstance(data, dict):
|
||||
data = {
|
||||
k: (
|
||||
_move_tensor(v)
|
||||
if k != "inference_params"
|
||||
else v._replace(attention_mask=_move_tensor(v.attention_mask))
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
||||
return data
|
||||
|
||||
|
||||
def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
|
||||
if isinstance(norm, float):
|
||||
norm = torch.Tensor([norm])
|
||||
if move_to_cuda:
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
return norm
|
||||
|
||||
|
||||
def get_current_device() -> torch.device:
|
||||
"""
|
||||
Returns currently selected device (gpu/cpu).
|
||||
If cuda available, return gpu, otherwise return cpu.
|
||||
"""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def get_batch_size(data):
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.size(0)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
if isinstance(data[0], dict):
|
||||
return data[0][list(data[0].keys())[0]].size(0)
|
||||
return data[0].size(0)
|
||||
elif isinstance(data, dict):
|
||||
return data[list(data.keys())[0]].size(0)
|
||||
|
||||
|
||||
def filter_kwargs(func, kwargs):
|
||||
sig = inspect.signature(func)
|
||||
return {k: v for k, v in kwargs.items() if k in sig.parameters}
|
||||
|
||||
|
||||
def launch_time():
|
||||
global CURRENT_TIME
|
||||
if not CURRENT_TIME:
|
||||
CURRENT_TIME = datetime.now().strftime("%b%d_%H-%M-%S")
|
||||
return CURRENT_TIME
|
||||
|
||||
|
||||
def set_random_seed(seed):
|
||||
"""Set random seed for reproducability."""
|
||||
# It is recommended to use this only when inference.
|
||||
if seed is not None:
|
||||
assert seed > 0
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
# if you are using multi-GPU.
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def conditional_context(context_manager, enable=True):
|
||||
if enable:
|
||||
with context_manager:
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
class BatchSkipper:
|
||||
"""
|
||||
BatchSkipper is used to determine whether to skip the current batch_idx.
|
||||
"""
|
||||
|
||||
def __init__(self, skip_batches):
|
||||
if skip_batches == "":
|
||||
pass
|
||||
intervals = skip_batches.split(",")
|
||||
spans = []
|
||||
if skip_batches != "":
|
||||
for interval in intervals:
|
||||
if "-" in interval:
|
||||
start, end = map(int, interval.split("-"))
|
||||
else:
|
||||
start, end = int(interval), int(interval)
|
||||
if spans:
|
||||
assert spans[-1] <= start
|
||||
spans.extend((start, end + 1))
|
||||
self.spans = spans
|
||||
|
||||
def __call__(self, batch_count):
|
||||
index = bisect.bisect_right(self.spans, batch_count)
|
||||
return index % 2 == 1
|
||||
|
||||
|
||||
class SingletonMeta(type):
|
||||
"""
|
||||
Singleton Meta.
|
||||
"""
|
||||
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||
else:
|
||||
assert (
|
||||
len(args) == 0 and len(kwargs) == 0
|
||||
), f"{cls.__name__} is a singleton class and a instance has been created."
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
def get_megatron_flops(
|
||||
elapsed_time_per_iter,
|
||||
checkpoint=False,
|
||||
seq_len=2048,
|
||||
hidden_size=12,
|
||||
num_layers=32,
|
||||
vocab_size=12,
|
||||
global_batch_size=4,
|
||||
global_world_size=1,
|
||||
mlp_ratio=4,
|
||||
use_swiglu=True,
|
||||
):
|
||||
"""
|
||||
Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf
|
||||
"""
|
||||
|
||||
checkpoint_activations_factor = 4 if checkpoint else 3
|
||||
|
||||
if use_swiglu:
|
||||
mlp_ratio = mlp_ratio * 3 / 2
|
||||
|
||||
flops_per_iteration = (
|
||||
checkpoint_activations_factor
|
||||
* (
|
||||
(8 + mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2
|
||||
+ 4 * global_batch_size * seq_len**2 * hidden_size
|
||||
)
|
||||
) * num_layers + 6 * global_batch_size * seq_len * hidden_size * vocab_size
|
||||
|
||||
tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
|
||||
return tflops
|
|
@ -0,0 +1,41 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
|
||||
LOGGER_NAME = "internlm"
|
||||
LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s"
|
||||
LOGGER_LEVEL = "info"
|
||||
LOGGER_LEVEL_CHOICES = ["debug", "info", "warning", "error", "critical"]
|
||||
LOGGER_LEVEL_HELP = (
|
||||
"The logging level threshold, choices=['debug', 'info', 'warning', 'error', 'critical'], default='info'"
|
||||
)
|
||||
|
||||
|
||||
def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL) -> logging.Logger:
|
||||
"""Configure the logger that is used for uniscale framework.
|
||||
|
||||
Args:
|
||||
logger_name (str): used to create or get the correspoding logger in
|
||||
getLogger call. It will be "internlm" by default.
|
||||
logging_level (str, optional): Logging level in string or logging enum.
|
||||
|
||||
Returns:
|
||||
logger (logging.Logger): the created or modified logger.
|
||||
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
|
||||
if logging_level not in LOGGER_LEVEL_CHOICES:
|
||||
logging_level = LOGGER_LEVEL
|
||||
print(LOGGER_LEVEL_HELP)
|
||||
|
||||
logging_level = logging.getLevelName(logging_level.upper())
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging_level)
|
||||
logger.setLevel(logging_level)
|
||||
handler.setFormatter(logging.Formatter(LOGGER_FORMAT))
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
|
@ -0,0 +1,111 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class _Timer:
|
||||
"""Timer."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name_ = name
|
||||
self.elapsed_ = 0.0
|
||||
self.started_ = False
|
||||
self.start_time = time.time()
|
||||
|
||||
def start(self):
|
||||
"""Start the timer."""
|
||||
assert not self.started_, "timer has already been started"
|
||||
torch.cuda.synchronize()
|
||||
self.start_time = time.time()
|
||||
self.started_ = True
|
||||
|
||||
def stop(self):
|
||||
"""Stop the timer."""
|
||||
assert self.started_, "timer is not started"
|
||||
torch.cuda.synchronize()
|
||||
self.elapsed_ += time.time() - self.start_time
|
||||
self.started_ = False
|
||||
|
||||
def reset(self):
|
||||
"""Reset timer."""
|
||||
self.elapsed_ = 0.0
|
||||
self.started_ = False
|
||||
|
||||
def elapsed(self, reset=True):
|
||||
"""Calculate the elapsed time."""
|
||||
started_ = self.started_
|
||||
# If the timing in progress, end it first.
|
||||
if self.started_:
|
||||
self.stop()
|
||||
# Get the elapsed time.
|
||||
elapsed_ = self.elapsed_
|
||||
# Reset the elapsed time
|
||||
if reset:
|
||||
self.reset()
|
||||
# If timing was in progress, set it back.
|
||||
if started_:
|
||||
self.start()
|
||||
return elapsed_
|
||||
|
||||
|
||||
class Timers:
|
||||
"""Group of timers."""
|
||||
|
||||
def __init__(self):
|
||||
self.timers = {}
|
||||
|
||||
def __call__(self, name):
|
||||
if name not in self.timers:
|
||||
self.timers[name] = _Timer(name)
|
||||
return self.timers[name]
|
||||
|
||||
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
|
||||
"""Write timers to a tensorboard writer"""
|
||||
# currently when using add_scalars,
|
||||
# torch.utils.add_scalars makes each timer its own run, which
|
||||
# polutes the runs list, so we just add each as a scalar
|
||||
assert normalizer > 0.0
|
||||
for name in names:
|
||||
if name in self.timers:
|
||||
value = self.timers[name].elapsed(reset=reset) / normalizer
|
||||
writer.add_scalar(f"time/{name}-time", value, iteration)
|
||||
|
||||
def log(self, names, logger, normalizer=1.0, reset=True):
|
||||
"""Log a group of timers."""
|
||||
assert normalizer > 0.0
|
||||
string = ""
|
||||
for name in names:
|
||||
if name in self.timers:
|
||||
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
|
||||
string += " | {}: {:.2f}".format(name, elapsed_time)
|
||||
if not len(string): # pylint: disable=C1802
|
||||
return
|
||||
string = "time (ms)" + string
|
||||
|
||||
logger.info(string)
|
||||
return string
|
||||
|
||||
def debug(self, names, logger, normalizer=1.0, reset=True):
|
||||
"""Log a group of timers."""
|
||||
assert normalizer > 0.0
|
||||
string = ""
|
||||
for name in names:
|
||||
if name in self.timers:
|
||||
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
|
||||
string += " | {}: {:.2f}".format(name, elapsed_time)
|
||||
if not len(string): # pylint: disable=C1802
|
||||
return
|
||||
string = "time (ms)" + string
|
||||
|
||||
logger.debug(string)
|
||||
return string
|
||||
|
||||
def reset(self):
|
||||
for _, t in self.timers.items():
|
||||
t.reset()
|
||||
|
||||
|
||||
megatron_timer = Timers()
|
|
@ -0,0 +1,289 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.storage_manager import get_fns, llm_load, llm_save
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def get_model_topology(model):
|
||||
"""
|
||||
Returns:
|
||||
{
|
||||
'{name}': {'dim': int}
|
||||
}
|
||||
where name is the name of the module, and all parameters under this module are
|
||||
concatenated along the dimension 'dim'.
|
||||
"""
|
||||
|
||||
from flash_attn.modules.embedding import VocabParallelEmbedding
|
||||
|
||||
topos = {}
|
||||
for name, module in model.named_modules():
|
||||
# If it does not meet these conditions, it is shared between various tp/dp, and it is necessary to assert
|
||||
# that they are consistent.
|
||||
if isinstance(module, VocabParallelEmbedding):
|
||||
topos[name] = {"dim": 0}
|
||||
return topos
|
||||
|
||||
|
||||
def save_model_checkpoint(folder, model):
|
||||
"""
|
||||
Save the model according to the relationship between tp and dp. The principle is that the data of each tp
|
||||
will not be gathered and saved separately, which is equivalent to actual sharding. The saved weight is named
|
||||
- folder
|
||||
- model_tp{tp_rank}_pp{pp_rank}.pt
|
||||
|
||||
If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
|
||||
|
||||
Args:
|
||||
folder: The folder to save the model
|
||||
model: The model to be saved
|
||||
"""
|
||||
|
||||
states = model.state_dict()
|
||||
topo = get_model_topology(model)
|
||||
|
||||
if folder is not None:
|
||||
dp_size = gpc.get_world_size(ParallelMode.DATA)
|
||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
# TODO In theory, we should also consider pp level, but since pp is generally a state across machines,
|
||||
# even if pp is not considered, it will definitely not be written on the same machine.
|
||||
should_save_rank_pair = set() # (tp_rank, dp_rank)
|
||||
for i in range(tp_size):
|
||||
should_save_rank_pair.add((i, i % dp_size))
|
||||
|
||||
if (tp_rank, dp_rank) in should_save_rank_pair:
|
||||
fn = f"model_tp{tp_rank}_pp{pp_rank}.pt"
|
||||
fp = os.path.join(folder, fn)
|
||||
llm_save(fp, saved_obj=states)
|
||||
topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
|
||||
topo_fp = os.path.join(folder, topo_fn)
|
||||
llm_save(topo_fp, saved_obj=topo)
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
def load_model_checkpoint(folder, model):
|
||||
"""
|
||||
There should be weights with names similar to the following under the folder.
|
||||
- folder
|
||||
- model_tp{tp_rank}_pp{pp_rank}.pt
|
||||
|
||||
If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
|
||||
"""
|
||||
|
||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
fns = get_fns(folder)
|
||||
max_pp, max_tp = 0, 0
|
||||
for fn in fns:
|
||||
if fn.startswith("model_t") and not fn.endswith(".md5"):
|
||||
segements = os.path.splitext(fn)[0].split("_")
|
||||
max_pp = max(max_pp, int(segements[-1][2:]))
|
||||
max_tp = max(max_tp, int(segements[-2][2:]))
|
||||
|
||||
assert (
|
||||
pp_size == max_pp + 1
|
||||
), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
|
||||
assert (
|
||||
tp_size == max_tp + 1
|
||||
), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
|
||||
|
||||
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
|
||||
fp = os.path.join(folder, should_load_name)
|
||||
states = llm_load(fp, map_location=get_current_device())
|
||||
|
||||
missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
|
||||
if len(missing_k) != 0:
|
||||
logger.warning(f"Warning: missing keys {missing_k}")
|
||||
if len(unexpected_keys) != 0:
|
||||
logger.warning(f"Warning: unexpected keys {unexpected_keys}")
|
||||
|
||||
# avoid to cuda oom, Ref: https://discuss.pytorch.org/t/load-state-dict-causes-memory-leak/36189/11
|
||||
del states
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def save_optimizer_checkpoint(optim, state_path):
|
||||
"""Store the state of the optimizer to the local file system or remote OSS.
|
||||
|
||||
Args:
|
||||
optim (Optimizer)
|
||||
state_path (str): The state loading path of optimizer.
|
||||
"""
|
||||
|
||||
# TODO sanity check for optimizer type
|
||||
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
|
||||
|
||||
states = optim.state_dict()
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
if gpc.get_global_rank() < optim.zero_world_size:
|
||||
llm_save(os.path.join(state_path, fp), states)
|
||||
if "zero_devide_optim_plan" in states:
|
||||
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
|
||||
fp_meta = os.path.join(state_path, optim.rank_unique_id)
|
||||
llm_save(fp_meta, params_per_rank_id_dict)
|
||||
else:
|
||||
llm_save(os.path.join(state_path, fp), states)
|
||||
|
||||
|
||||
def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
|
||||
"""
|
||||
Save checkpoint to the given folder path.
|
||||
"""
|
||||
|
||||
start = time.time()
|
||||
torch.distributed.barrier()
|
||||
folder = os.path.join(folder, str(train_state.step_count))
|
||||
logger.info(
|
||||
f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..."
|
||||
)
|
||||
|
||||
timer("save-model").start()
|
||||
save_model_checkpoint(folder=folder, model=model)
|
||||
timer("save-model").stop()
|
||||
|
||||
timer("save-optimizer").start()
|
||||
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
|
||||
timer("save-optimizer").stop()
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
scheduler_states = scheduler.state_dict()
|
||||
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
|
||||
|
||||
sampler_state = train_state.batch_sampler.state_dict()
|
||||
llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
|
||||
llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
|
||||
|
||||
if model_config is not None:
|
||||
llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
timer.log(["save-model", "save-optimizer"], logger=logger)
|
||||
logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
|
||||
|
||||
|
||||
def load_optimizer_checkpoint(folder, optim):
|
||||
"""Load the optimizer state from the local file system or remote
|
||||
object storage Service (OSS).
|
||||
|
||||
Args:
|
||||
optim (Optimizer): optimizer
|
||||
folder (str): The FS/OSS path where the optimizer will be stored.
|
||||
"""
|
||||
|
||||
fns = get_fns(folder)
|
||||
max_tp, max_pp, max_zero = 0, 0, 0
|
||||
for fn in fns:
|
||||
if fn.startswith("optimizer_") and not fn.endswith(".md5"):
|
||||
_, tp, pp, zero = os.path.splitext(fn)[0].split("_")
|
||||
max_zero = max(max_zero, int(zero[2:]))
|
||||
max_tp = max(max_tp, int(tp[2:]))
|
||||
max_pp = max(max_pp, int(pp[2:]))
|
||||
|
||||
zero_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
|
||||
assert (
|
||||
zero_size == max_zero + 1
|
||||
), f"The weights are save for {max_zero+1} data parallel, while current has {zero_size} zero broadcast range."
|
||||
assert (
|
||||
pp_size == max_pp + 1
|
||||
), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
|
||||
assert (
|
||||
tp_size == max_tp + 1
|
||||
), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
|
||||
|
||||
fp = f"optimizer_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
|
||||
fp += f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
|
||||
fp += f"zo{zero_rank}.pt"
|
||||
states = llm_load(os.path.join(folder, fp), map_location=get_current_device())
|
||||
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
fp_meta = os.path.join(folder, optim.rank_unique_id)
|
||||
try:
|
||||
zero_devide_optim_plan = llm_load(fp_meta)
|
||||
states.update({"zero_devide_optim_plan": zero_devide_optim_plan})
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Read zero optimzer split file '{fp_meta}', for '{e}'"
|
||||
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
|
||||
)
|
||||
|
||||
optim.load_state_dict(states)
|
||||
del states
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_sampler(ckpt_path: str, sampler):
|
||||
sampler_states = llm_load(os.path.join(ckpt_path, "sampler.pt"))
|
||||
sampler.load_state_dict(sampler_states)
|
||||
if gpc.is_rank_for_log():
|
||||
pstate = copy.deepcopy(sampler_states)
|
||||
pstate.pop("indices")
|
||||
pstate.pop("rng_state")
|
||||
logger.info(f"reload sampler_states:{pstate}")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_context(ckpt_path: str, train_dl, train_state: TrainState):
|
||||
context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt"))
|
||||
train_state.load_state_dict(context_stuffs, train_dl)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"reload train_state:{train_state}")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
|
||||
scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
|
||||
if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
f"Using new learning rate {learning_rate} to replace old learn rate {scheduler_states['base_lrs'][0]}."
|
||||
)
|
||||
|
||||
base_lrs = copy.deepcopy(scheduler_states["base_lrs"])
|
||||
scheduler_states["base_lrs"] = [learning_rate] * len(scheduler_states["base_lrs"])
|
||||
if "after_scheduler_dict" in scheduler_states:
|
||||
scheduler_states["after_scheduler_dict"]["base_lrs"] = [learning_rate] * len(
|
||||
scheduler_states["after_scheduler_dict"]["base_lrs"]
|
||||
)
|
||||
|
||||
lr_scheduler.load_state_dict(scheduler_states)
|
||||
lr_scheduler.last_epoch = train_state.step_count + 1
|
||||
|
||||
ratios = [learning_rate / lr for lr in base_lrs]
|
||||
for idx, param_group in enumerate(optimizer.param_groups):
|
||||
param_group["lr"] = param_group["lr"] * ratios[idx]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"reload load_scheduler:{lr_scheduler}")
|
|
@ -0,0 +1,48 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
def is_model_parallel_parameter(p):
|
||||
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
|
||||
|
||||
|
||||
def sync_model_param(model, parallel_mode):
|
||||
r"""Make sure data parameters are consistent during Data Parallel Mode.
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
|
||||
parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked.
|
||||
"""
|
||||
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
|
||||
for param in model.parameters():
|
||||
ranks = gpc.get_ranks_in_group(parallel_mode)
|
||||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
|
||||
|
||||
|
||||
def sync_model_param_within_tp(model):
|
||||
r"""This function is changed from colossalai, which is ``sync_model_param``.
|
||||
|
||||
We modified this function to make sure it only sync parameters within tensor parallelism
|
||||
but they are not splitted by tensor parallelism.
|
||||
This function is used to make sure parameters that are not splitted by tensor parallelism
|
||||
are the same across each tensor parallelism.
|
||||
For example, parameters like RMSNorm, LayerNorm...
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
|
||||
"""
|
||||
parallel_mode = ParallelMode.TENSOR
|
||||
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
|
||||
for param in model.parameters():
|
||||
if not is_model_parallel_parameter(param):
|
||||
ranks = gpc.get_ranks_in_group(parallel_mode)
|
||||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
|
||||
|
||||
|
||||
def is_no_pp_or_last_stage():
|
||||
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
|
||||
class Registry:
|
||||
"""This is a registry class used to register classes and modules so that a universal
|
||||
object builder can be enabled.
|
||||
|
||||
Args:
|
||||
name (str): The name of the registry.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self._name = name
|
||||
self._registry = dict()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def register_module(self, module_name: str):
|
||||
"""Registers a module represented in `module_class`.
|
||||
|
||||
Args:
|
||||
module_class (class): The module to be registered.
|
||||
Returns:
|
||||
class: The module to be registered, so as to use it normally if via importing.
|
||||
Raises:
|
||||
AssertionError: Raises an AssertionError if the module has already been registered before.
|
||||
"""
|
||||
|
||||
assert module_name not in self._registry, f"{module_name} not found in {self.name}"
|
||||
|
||||
def decorator_wrapper(original_func):
|
||||
self._registry[module_name] = original_func
|
||||
return original_func
|
||||
|
||||
return decorator_wrapper
|
||||
|
||||
def get_module(self, module_name: str):
|
||||
"""Retrieves a module with name `module_name` and returns the module if it has
|
||||
already been registered before.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to be retrieved.
|
||||
Returns:
|
||||
:class:`object`: The retrieved module or None.
|
||||
Raises:
|
||||
NameError: Raises a NameError if the module to be retrieved has neither been
|
||||
registered directly nor as third party modules before.
|
||||
"""
|
||||
if module_name in self._registry:
|
||||
return self._registry[module_name]
|
||||
raise NameError(f"Module {module_name} not found in the registry {self.name}")
|
||||
|
||||
def has(self, module_name: str):
|
||||
"""Searches for a module with name `module_name` and returns a boolean value indicating
|
||||
whether the module has been registered directly or as third party modules before.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to be searched for.
|
||||
Returns:
|
||||
bool: A boolean value indicating whether the module has been registered directly or
|
||||
as third party modules before.
|
||||
"""
|
||||
found_flag = module_name in self._registry
|
||||
|
||||
return found_flag
|
||||
|
||||
|
||||
MODEL_INITIALIZER = Registry("model_initializer")
|
|
@ -0,0 +1,351 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
import torch
|
||||
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
|
||||
|
||||
MB = 1024**2
|
||||
|
||||
storage_manager = None
|
||||
|
||||
|
||||
def check_folder(fp: str):
|
||||
storage_manager.assert_fp_exists(fp)
|
||||
|
||||
|
||||
def get_fns(fp: str):
|
||||
return storage_manager.get_fns(fp)
|
||||
|
||||
|
||||
def llm_load(fp: str, *args, **kwargs):
|
||||
return storage_manager.load(fp, *args, **kwargs)
|
||||
|
||||
|
||||
def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
|
||||
storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
|
||||
|
||||
|
||||
class CheckpointType(Enum):
|
||||
NORMAL_CHECKPOINT = 1
|
||||
|
||||
|
||||
class StorageClient:
|
||||
"""
|
||||
StorageClient as a client for s3 storage access.
|
||||
"""
|
||||
|
||||
def __init__(self, handler) -> None:
|
||||
self.handler = handler
|
||||
|
||||
@staticmethod
|
||||
def load(client, load_path: str, map_location):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(client):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_fns(client):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Boto3MetaInfo:
|
||||
def __init__(self, client: StorageClient, bucket_name: str, endpoint: str, file_path: str) -> None:
|
||||
self.client = client
|
||||
self.bucket_name = bucket_name
|
||||
self.endpoint = endpoint
|
||||
self.file_path = file_path
|
||||
|
||||
|
||||
class LocalMetaInfo:
|
||||
def __init__(self, client: StorageClient, dest_path: str) -> None:
|
||||
self.client = client
|
||||
self.dest_path = dest_path
|
||||
|
||||
|
||||
def unpack_meta(meta):
|
||||
args = []
|
||||
for k, v in meta.__dict__.items():
|
||||
if k == "endpoint":
|
||||
continue
|
||||
args.append(v)
|
||||
return args
|
||||
|
||||
|
||||
def compute_file_md5_by_chunk(file_name: str):
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(file_name, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def get_boto3_meta(fp: str) -> Boto3MetaInfo:
|
||||
assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
|
||||
parts = fp.lstrip("s3://").split(os.path.sep)
|
||||
match = boto3_url_re.match(parts[0])
|
||||
assert match is not None, f"url '{fp}' is not a valid boto3 url"
|
||||
bucket_name, endpoint = match.group(1), match.group(2)
|
||||
endpoint = "http://" + endpoint + ":80"
|
||||
return Boto3MetaInfo(None, bucket_name, endpoint, os.path.sep.join(parts[1:]))
|
||||
|
||||
|
||||
def get_local_meta(fp: str) -> LocalMetaInfo:
|
||||
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
|
||||
return LocalMetaInfo(None, fp)
|
||||
|
||||
|
||||
class Boto3Client(StorageClient):
|
||||
"""
|
||||
Boto3Client
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
s3_endpoint_url: str,
|
||||
use_threads: int = True,
|
||||
multipart_chunksize=8 * MB,
|
||||
max_concurrency: int = 10,
|
||||
multipart_threshold=100 * MB,
|
||||
) -> None:
|
||||
"""S3 object/file storage management class
|
||||
|
||||
Args:
|
||||
s3_access_keys_id (str): S3 access key ID.
|
||||
s3_secret_access_key (str): S3 secret access key.
|
||||
use_threads (bool, optional): Whether to enable multipart. Defaults to True.
|
||||
multipart_chunksize (_type_, optional): Defaults to 8*MB.
|
||||
max_concurrency (int, optional): Defaults to 10.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Connection failures caused by misconfiguration or network problems.
|
||||
"""
|
||||
super().__init__(boto3)
|
||||
self.botocore = botocore
|
||||
try:
|
||||
s3_access_key_id = os.environ["S3_ACCESS_KEY_ID"]
|
||||
s3_secret_access_key = os.environ["S3_SECRET_ACCESS_KEY_ID"]
|
||||
except KeyError as exc:
|
||||
raise RuntimeError(
|
||||
"Please set boto3 bucket 'S3_ACCESS_KEY_ID' and 'S3_SECRET_ACCESS_KEY_ID' using environment variable!"
|
||||
) from exc
|
||||
|
||||
self.client = self.handler.client(
|
||||
"s3",
|
||||
"",
|
||||
use_ssl=False,
|
||||
verify=False,
|
||||
endpoint_url=s3_endpoint_url,
|
||||
aws_access_key_id=s3_access_key_id,
|
||||
aws_secret_access_key=s3_secret_access_key,
|
||||
)
|
||||
|
||||
self.config = self.handler.s3.transfer.TransferConfig(
|
||||
multipart_threshold=multipart_threshold,
|
||||
max_concurrency=max_concurrency,
|
||||
multipart_chunksize=multipart_chunksize,
|
||||
use_threads=use_threads,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
|
||||
assert saved_obj is not None, "saved_obj is None!"
|
||||
try:
|
||||
with io.BytesIO() as f:
|
||||
torch.save(saved_obj, f, *args, **kwargs)
|
||||
f.seek(0)
|
||||
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
|
||||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
|
||||
"""
|
||||
Args:
|
||||
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
|
||||
"""
|
||||
try:
|
||||
with io.BytesIO() as f:
|
||||
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
|
||||
f.seek(0)
|
||||
states = torch.load(f, *args, map_location=map_location, **kwargs)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
|
||||
) from exc
|
||||
return states
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(
|
||||
handler,
|
||||
bucket_name: str,
|
||||
fp: str,
|
||||
):
|
||||
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, bucket_name: str, fp: str):
|
||||
"""
|
||||
Ref: https://stackoverflow.com/questions/54314563/
|
||||
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
|
||||
"""
|
||||
paginator = handler.client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
|
||||
|
||||
folder_name_list = []
|
||||
for page in pages:
|
||||
for obj in page["Contents"]:
|
||||
fp: str = obj["Key"]
|
||||
folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
|
||||
return folder_name_list
|
||||
|
||||
|
||||
class LocalClient(StorageClient):
|
||||
"""
|
||||
Storage Client for local NFS.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None: # pylint: disable=W0613
|
||||
super().__init__(None)
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert saved_obj is not None
|
||||
fp_dirname = os.path.dirname(fp)
|
||||
if not os.path.exists(fp_dirname):
|
||||
os.makedirs(fp_dirname, exist_ok=True)
|
||||
torch.save(saved_obj, fp, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load(handler, fp: str, *args, map_location="cpu", **kwargs):
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert os.path.exists(fp), f"{fp} is not found!"
|
||||
with open(fp, "rb") as f:
|
||||
states = torch.load(f, map_location=map_location, *args, **kwargs)
|
||||
return states
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(handler, folder):
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert os.path.exists(folder), folder
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, folder):
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert os.path.exists(folder), f"folder '{folder}' not exists!"
|
||||
fns = os.listdir(folder)
|
||||
return fns
|
||||
|
||||
@staticmethod
|
||||
def delete_obj(handler, fp: str):
|
||||
assert isinstance(handler, LocalClient)
|
||||
if not os.path.isdir(fp):
|
||||
os.remove(fp)
|
||||
|
||||
|
||||
class StorageManager(metaclass=SingletonMeta):
|
||||
"""
|
||||
Storage Manager for saving or loading checkpoint.
|
||||
"""
|
||||
|
||||
BACKEND_TYPE = {"boto3", "local"}
|
||||
BACKEND_INIT_METHOD = {
|
||||
"boto3": Boto3Client,
|
||||
"local": LocalClient,
|
||||
}
|
||||
CLI_DICT = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
|
||||
"""
|
||||
example:
|
||||
local:/path/to/checkpoint
|
||||
boto3:s3://model_weights/0331/120bi
|
||||
|
||||
Args:
|
||||
path (str): _description_
|
||||
"""
|
||||
try:
|
||||
backend, path = path.split(":", maxsplit=1)
|
||||
except Exception as exc:
|
||||
raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc
|
||||
|
||||
init_args = (None,)
|
||||
if backend == "local":
|
||||
meta_info = get_local_meta(path)
|
||||
backend_key = backend
|
||||
elif backend == "boto3":
|
||||
meta_info = get_boto3_meta(path)
|
||||
backend_key = backend + ":" + meta_info.endpoint
|
||||
init_args = (meta_info.endpoint,)
|
||||
if (
|
||||
"http_proxy" in os.environ
|
||||
or "https_proxy" in os.environ
|
||||
or "HTTP_PROXY" in os.environ
|
||||
or "HTTPS_PROXY" in os.environ
|
||||
):
|
||||
raise RuntimeWarning(
|
||||
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
|
||||
the proxy may make boto3 unavailable or affect performance."
|
||||
)
|
||||
|
||||
assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
|
||||
|
||||
# boto3 backend need special treatment.
|
||||
if backend_key not in StorageManager.CLI_DICT:
|
||||
StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})
|
||||
|
||||
meta_info.client = StorageManager.CLI_DICT[backend_key]
|
||||
|
||||
return meta_info
|
||||
|
||||
def assert_fp_exists(self, folder) -> None:
|
||||
meta = self._get_client(path=folder)
|
||||
meta.client.assert_fp_exists(*unpack_meta(meta))
|
||||
|
||||
def get_fns(self, folder) -> List[str]:
|
||||
meta = self._get_client(path=folder)
|
||||
return meta.client.get_fns(*unpack_meta(meta))
|
||||
|
||||
def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
|
||||
meta = self._get_client(path=save_path)
|
||||
|
||||
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
||||
|
||||
def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
|
||||
|
||||
meta = self._get_client(path=load_path)
|
||||
return meta.client.load(*unpack_meta(meta), map_location=map_location, *args, **kwargs)
|
||||
|
||||
def delete_obj(self, fp: str):
|
||||
meta = self._get_client(path=fp)
|
||||
meta.client.delete_obj(*unpack_meta(meta))
|
||||
|
||||
|
||||
storage_manager = StorageManager()
|
|
@ -0,0 +1,14 @@
|
|||
transformers>=4.25.1
|
||||
numpy
|
||||
tqdm
|
||||
psutil
|
||||
packaging
|
||||
pre-commit
|
||||
ninja
|
||||
gputil
|
||||
pytest
|
||||
packaging
|
||||
boto3
|
||||
botocore
|
||||
torch-scatter
|
||||
-f https://data.pyg.org/whl/torch-1.13.0+cu117.html
|
|
@ -0,0 +1,4 @@
|
|||
--extra-index-url https://download.pytorch.org/whl/cu117
|
||||
torch==1.13.1+cu117
|
||||
torchvision==0.14.1+cu117
|
||||
torchaudio==0.13.1
|
|
@ -0,0 +1,2 @@
|
|||
sonar.projectKey=InternLM
|
||||
sonar.python.version=3.6,3.7,3.8,3.9,3.10
|
|
@ -0,0 +1,53 @@
|
|||
本目录提供辅助模型训练的一些工具,文件结构如下所示:
|
||||
```bash
|
||||
├── transformers # 适配hugging face的transformers的一些工具
|
||||
│ ├── configuration_internlm.py # config适配工具
|
||||
│ ├── modeling_internlm.py # model适配工具
|
||||
│ └── tokenization_internlm.py # tokenizer适配工具
|
||||
├── convert2hf.py # 模型适配hugging face工具
|
||||
└── tokenizer.py # 将原始数据转换成bin和meta文件的工具
|
||||
```
|
||||
|
||||
# tokenizer.py
|
||||
生成原始数据的`bin`和`meta`文件需要使用`tokenizer`,我们通过在`tools/tokenizer.py`中指定模型参数路径的方式来导入tokenizer模型。目前我们提供了`V7.model`来生成tokens。若想使用不同的模型,可直接修改`tokernizer.py`中的模型参数路径。
|
||||
|
||||
我们可以运行以下命令生成原始数据对应的`bin`和`meta`文件,其中参数`raw_data_name`表示原始数据集的文件名称,`input_file_type`表示原始数据集的文件格式,我们目前支持`txt`、`json`和`jsonl`这三种格式,`bin`表示生成的`bin`文件的保存路径。
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'text' or 'json' or 'jsonl' --bin your_output_bin_path
|
||||
```
|
||||
|
||||
下面是一个数据处理的例子(这里只给出了`txt`格式的数据处理例子,`json`和`jsonl`的数据处理流程和`txt`的完全一致):
|
||||
|
||||
给定一个包含原始数据集的文件`raw_data.txt`,原始数据集如下所示:
|
||||
```bash
|
||||
感恩生活中的每一个细节,才能真正体会到幸福的滋味。
|
||||
梦想是人生的动力源泉,努力追逐,才能实现自己的目标。
|
||||
学会宽容和理解,才能建立真正和谐的人际关系。
|
||||
```
|
||||
|
||||
接下来,我们可以通过运行以下命令来生成`bin`和`meta`文件:
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
|
||||
```
|
||||
|
||||
需要注意的是,生成的`bin`文件需要保存在`cn`或者`en`或者`code`或者`ja`或者`ar`或者`kaoshi`这五个目录下,以区分数据集的类型。
|
||||
|
||||
其中,`cn`表示中文数据集;`en`表示英文数据集;`code`表示代码数据集;`ja`表示日语数据集;`ar`表示阿拉伯语数据集;`kaoshi`表示考试数据集。
|
||||
|
||||
生成的bin文件的格式如下:
|
||||
```python
|
||||
{"tokens": [73075, 75302, 69522, 69022, 98899, 67713, 68015, 81269, 74637, 75445, 99157]}
|
||||
{"tokens": [69469, 60355, 73026, 68524, 60846, 61844, 98899, 67775, 79241, 98899, 67713, 67800, 67453, 67838, 99157]}
|
||||
{"tokens": [68057, 79017, 60378, 68014, 98899, 67713, 67990, 68015, 70381, 67428, 61003, 67622, 99157]}
|
||||
```
|
||||
`bin`文件中的每一行均对应原始数据集中的每一个句子,表示每个句子的`token`(下文将用sequence指定)。
|
||||
|
||||
生成的`meta`文件的格式如下:
|
||||
```bash
|
||||
(0, 11), (90, 15), (208, 13)
|
||||
```
|
||||
在`meta`文件中,每个元组对应着`bin`文件中每一个`sequence`的元信息。其中,元组的第一个元素表示每个`sequence`在所有`sequence`中的`starting index`,第二个元素表示每个`sequence`中有多少个`tokens`。
|
||||
|
||||
例如,对于第一个`sequence`,`starting index`为 0,有 11 个`tokens`;对于第二个`sequence`,由于第一个`sequence`转换为`string`后的长度为`89`,因此它的`starting index`为 90,有 15 个`tokens`。
|
||||
|
||||
`json`和`jsonl`类型的文件的`bin`和`meta`文件格式和`txt`一致,此处不再赘叙。
|
|
@ -0,0 +1,50 @@
|
|||
This directory provide some tools for model training with the following file structure.
|
||||
```bash
|
||||
├── transformers # tools for adapting Hugging Face's transformers
|
||||
│ ├── configuration_internlm.py # tools for adapting config
|
||||
│ ├── modeling_internlm.py # tools for adapting model
|
||||
│ └── tokenization_internlm.py # tools for adapting tokenizer
|
||||
├── convert2hf.py # tools for adapting models to Hugging Face's format
|
||||
└── tokenizer.py # tools for generating `bin` and `meta` file for raw data
|
||||
```
|
||||
|
||||
# tokenizer.py
|
||||
We need to use a `tokenizer` to generate `bin` and `meta` files for raw data. We import the tokenizer model by specifying the model weight path in `tools/tokenizer.py`. Currently, we provide `V7.model` to generate tokens. If you want to use a different model, you can modify the model weight path in `tokenizer.py` directly.
|
||||
|
||||
We can run the following command to generate `bin` and `meta` files for raw data, where the parameter `raw_data_name` indicates the file name of raw data, `input_file_type` denotes the raw data format, which should be `txt`, `json` and `jsonl`, and `bin` indicates the path to save the generated `bin` file.
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'text' or 'json' or 'jsonl' --bin your_output_bin_path
|
||||
```
|
||||
|
||||
An example of data processing in `txt` format is given here (the data processing for `json` and `jsonl` is identical to that for `txt`).
|
||||
|
||||
Given a file `raw_data.txt` containg raw data with the following content.
|
||||
```bash
|
||||
Appreciate every detail in life to truly taste the flavor of happiness.
|
||||
Dreams are the source of life’s motivation. Pursue them diligently to achieve your goals.
|
||||
Learn to be tolerant and understanding to establish truly harmonious interpersonal relationships.
|
||||
```
|
||||
Next, we can run the following command to generate `bin` and `meta` files for raw data.
|
||||
```bash
|
||||
$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
|
||||
```
|
||||
|
||||
It should be noted that the generated `bin` files should be placed in one of the following directories to clarify the data type: `cn`(Chinese), `en`(English), `code`(code data), `ja`(Japanese), `ar`(Arabic) and `kaoshi`(kaoshi data).
|
||||
|
||||
The format of generated `bin` file is as follows.
|
||||
```python
|
||||
{"tokens": [98655, 2317, 2922, 6649, 1595, 7856, 435, 2424, 442, 9556, 12807, 410, 17313, 446, 23331, 95746]}
|
||||
{"tokens": [98655, 302, 1383, 269, 657, 410, 2687, 446, 2424, 98667, 269, 25220, 281, 523, 1874, 492, 1248, 38127, 4563, 442, 11227, 829, 8980, 95746]}
|
||||
{"tokens": [98655, 24190, 442, 517, 15013, 649, 454, 8793, 442, 5849, 9556, 17917, 1369, 1084, 29890, 12021, 95746]}
|
||||
```
|
||||
In the generated `bin` file, each line (`sequence`) corresponds to the `tokens` for each sentence in the raw data.
|
||||
|
||||
The format of generated `meta` file in as follows.
|
||||
```bash
|
||||
(0, 16), (110, 24), (262, 17)
|
||||
```
|
||||
Each tuple in the `meta` file represents the meta information of each `sequence` where the first element in the tuple indicates the `starting index` of each `sequence` among all `sequences` and the second element indicates the amount of `tokens` for each `sequence`.
|
||||
|
||||
For example, the `starting index` is 0 for the first `sequence` with 16 `tokens`. Since the length of `sequence` in `string` format is 109, the `starting index` is 110. And the number of `tokens` of the sencond `sequence` is 24.
|
||||
|
||||
The `bin` and `meta` file formats for `json` and `jsonl` type files are the same as for `txt`, so we won't go over them here.
|
|
@ -0,0 +1,171 @@
|
|||
import argparse
|
||||
import json
|
||||
import sentencepiece as spm
|
||||
from tqdm import tqdm
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
|
||||
def process(dataset_path, sp_model):
|
||||
"""Process data sample from input dataset
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path of dataset json file.
|
||||
sp_model (str): Path of tokenizer.
|
||||
|
||||
Yields:
|
||||
tuple: dumped processed data sample and length of tokens.
|
||||
"""
|
||||
|
||||
dataset = json.load(open(dataset_path))
|
||||
|
||||
for data in dataset:
|
||||
yield tokenize(get_chat_format_data(data), sp_model)
|
||||
|
||||
|
||||
def get_chat_format_data(ori_data):
|
||||
"""Format original data
|
||||
|
||||
Args:
|
||||
ori_data (dict): input data sample.
|
||||
|
||||
Returns:
|
||||
dict: data sample with chat format.
|
||||
"""
|
||||
input_str = ori_data['input']
|
||||
instruction_str = ori_data['instruction']
|
||||
output_str = ori_data['output']
|
||||
data = dict()
|
||||
if input_str != "":
|
||||
data['user'] = f'<|User|>:{instruction_str}\n{input_str}'
|
||||
else:
|
||||
data['user'] = f'<|User|>:{instruction_str}'
|
||||
data['bot'] = f'<|Bot|>:{output_str}'
|
||||
return data
|
||||
|
||||
|
||||
def tokenize(sample, sp_model):
|
||||
"""Tokenize input dataset
|
||||
|
||||
Args:
|
||||
sample (dict): Input data sample.
|
||||
sp_model (str): Path of tokenizer.
|
||||
|
||||
Returns:
|
||||
tuple: dumped processed data sample and length of tokens.
|
||||
"""
|
||||
special_tokens_map = {'<eoh>': 103167, '<eoa>': 103166, 'nl_id': 13}
|
||||
token_ids = [sp_model.bos_id()]
|
||||
human_s = sample['user']
|
||||
ass_s = sample['bot']
|
||||
|
||||
human_ids = sp_model.encode(human_s) + [
|
||||
special_tokens_map["<eoh>"], special_tokens_map['nl_id']
|
||||
]
|
||||
human_ids_ignore = [-token_id for token_id in human_ids]
|
||||
|
||||
ass_template_ids = sp_model.encode('<|Assistant|>:')
|
||||
ass_template_ids_ignore = [-token_ids for token_ids in ass_template_ids]
|
||||
ass_ids = ass_template_ids_ignore + sp_model.encode(ass_s[14:]) + [
|
||||
special_tokens_map["<eoa>"], special_tokens_map['nl_id']
|
||||
]
|
||||
|
||||
token_ids += human_ids_ignore + ass_ids
|
||||
if len(token_ids) > 2047:
|
||||
token_ids = token_ids[:2047]
|
||||
token_ids += [sp_model.eos_id()]
|
||||
line = str.encode(json.dumps({'tokens': token_ids}) + '\n')
|
||||
return line, len(token_ids)
|
||||
|
||||
|
||||
def dump_bin_meta_bin(samples, path, split_ratio=0.1):
|
||||
"""Dump processed dataset
|
||||
|
||||
Args:
|
||||
samples (dict): Input data sample.
|
||||
path (str): Path for output dataset.
|
||||
split_ratio (float): Ratio for validation dataset splitting.
|
||||
Default to: 0.1.
|
||||
|
||||
Returns:
|
||||
tuple: number of train/valid tokens of processed dataset,
|
||||
number of train/valid samples of processed dataset.
|
||||
"""
|
||||
|
||||
train_path = osp.join(path, 'train/en/')
|
||||
valid_path = osp.join(path, 'valid/en/')
|
||||
train_dir = Path(train_path)
|
||||
valid_dir = Path(valid_path)
|
||||
train_dir.mkdir(exist_ok=True, parents=True)
|
||||
valid_dir.mkdir(exist_ok=True, parents=True)
|
||||
train_f = open(train_dir.joinpath('dataset.bin'), 'wb')
|
||||
valid_f = open(valid_dir.joinpath('dataset.bin'), 'wb')
|
||||
|
||||
train_tokens = 0
|
||||
valid_tokens = 0
|
||||
last_train_position = 0
|
||||
last_valid_position = 0
|
||||
train_samples = 0
|
||||
valid_samples = 0
|
||||
train_meta = []
|
||||
valid_meta = []
|
||||
|
||||
sample_length = len(samples)
|
||||
np.random.seed(0)
|
||||
valid_indices = np.random.choice(
|
||||
range(sample_length), int(sample_length * split_ratio)).tolist()
|
||||
|
||||
count = -1
|
||||
for line, token_num in samples:
|
||||
count += 1
|
||||
if count in valid_indices:
|
||||
valid_tokens += token_num
|
||||
valid_f.write(line)
|
||||
valid_meta.append((last_valid_position, token_num))
|
||||
last_valid_position += len(line)
|
||||
valid_samples += 1
|
||||
else:
|
||||
train_tokens += token_num
|
||||
train_f.write(line)
|
||||
train_meta.append((last_train_position, token_num))
|
||||
last_train_position += len(line)
|
||||
train_samples += 1
|
||||
|
||||
train_f.close()
|
||||
valid_f.close()
|
||||
np.save(open(train_dir.joinpath('dataset.bin.meta'), 'wb'), train_meta)
|
||||
np.save(open(valid_dir.joinpath('dataset.bin.meta'), "wb"), valid_meta)
|
||||
|
||||
return train_tokens, valid_tokens, train_samples, valid_samples
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'dataset_path', type=str, help='path of dataset json file')
|
||||
parser.add_argument(
|
||||
'output_path', type=str, help='path of processed dataset')
|
||||
parser.add_argument('tokenizer_path', type=str, help='path of tokenizer')
|
||||
parser.add_argument(
|
||||
'--split_ratio',
|
||||
type=float,
|
||||
default=0.1,
|
||||
help='ratio for validation dataset splitting')
|
||||
|
||||
args = parser.parse_args()
|
||||
sp_model = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
|
||||
split_ratio = args.split_ratio
|
||||
samples = []
|
||||
|
||||
dataset = process(args.dataset_path, sp_model)
|
||||
for sample in tqdm(dataset):
|
||||
samples.append(sample)
|
||||
|
||||
train_tokens, valid_tokens, train_samples, valid_samples = \
|
||||
dump_bin_meta_bin(samples, args.output_path, args.split_ratio)
|
||||
print(f'number of train dataset: {train_samples}, '
|
||||
'number of train dataset token: {train_tokens}')
|
||||
print(f'number of validation dataset: {valid_samples}, '
|
||||
'number of validation dataset token: {valid_tokens}')
|
|
@ -0,0 +1,174 @@
|
|||
import argparse
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
from modeling_internlm import InternLMConfig, InternLMForCausalLM
|
||||
from tokenization_internlm import InternLMTokenizer
|
||||
|
||||
NUM_SHARDS = {
|
||||
"7B": 1,
|
||||
}
|
||||
|
||||
|
||||
def convert2hf(model_config, states_tp_pps):
|
||||
folder = f"/dev/shm/wait_to_upload_weight_tmp_{random.random()}/"
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
try:
|
||||
states = merge_pp(states_tp_pps)[0]
|
||||
|
||||
if "embedding.word_embeddings.weight" in states:
|
||||
embedding_key = "embedding.word_embeddings.weight"
|
||||
elif "embedding.weight" in states:
|
||||
embedding_key = "embedding.weight"
|
||||
else:
|
||||
print("Check embedding states'names in below:", flush=True)
|
||||
print(list(states.keys()), flush=True)
|
||||
|
||||
dims_per_head = model_config["hidden_size"] // model_config["num_attention_heads"]
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
||||
|
||||
current_states = {}
|
||||
|
||||
current_states["model.embed_tokens.weight"] = states.pop(embedding_key)
|
||||
current_states["model.norm.weight"] = states.pop("norm.weight")
|
||||
current_states["lm_head.weight"] = states.pop("head.weight")
|
||||
|
||||
for i in range(model_config["num_layers"]):
|
||||
states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq")
|
||||
|
||||
wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
|
||||
3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
|
||||
)
|
||||
bqkv = states.pop(f"blocks.{i}.mixer.Wqkv.bias").reshape(3, model_config["num_attention_heads"], -1)
|
||||
|
||||
current_states[f"model.layers.{i}.self_attn.q_proj.weight"] = wqkv[0].reshape(
|
||||
-1, model_config["hidden_size"]
|
||||
)
|
||||
current_states[f"model.layers.{i}.self_attn.q_proj.bias"] = bqkv[0].reshape(-1)
|
||||
current_states[f"model.layers.{i}.self_attn.k_proj.weight"] = wqkv[1].reshape(
|
||||
-1, model_config["hidden_size"]
|
||||
)
|
||||
current_states[f"model.layers.{i}.self_attn.k_proj.bias"] = bqkv[1].reshape(-1)
|
||||
current_states[f"model.layers.{i}.self_attn.v_proj.weight"] = wqkv[2].reshape(
|
||||
-1, model_config["hidden_size"]
|
||||
)
|
||||
current_states[f"model.layers.{i}.self_attn.v_proj.bias"] = bqkv[2].reshape(-1)
|
||||
|
||||
current_states[f"model.layers.{i}.self_attn.o_proj.weight"] = states.pop(
|
||||
f"blocks.{i}.mixer.out_proj.weight"
|
||||
)
|
||||
current_states[f"model.layers.{i}.self_attn.o_proj.bias"] = states.pop(f"blocks.{i}.mixer.out_proj.bias")
|
||||
|
||||
current_states[f"model.layers.{i}.mlp.gate_proj.weight"] = states.pop(f"blocks.{i}.mlp.w1.weight")
|
||||
current_states[f"model.layers.{i}.mlp.down_proj.weight"] = states.pop(f"blocks.{i}.mlp.w3.weight")
|
||||
current_states[f"model.layers.{i}.mlp.up_proj.weight"] = states.pop(f"blocks.{i}.mlp.w2.weight")
|
||||
|
||||
current_states[f"model.layers.{i}.input_layernorm.weight"] = states.pop(f"blocks.{i}.norm1.weight")
|
||||
current_states[f"model.layers.{i}.post_attention_layernorm.weight"] = states.pop(f"blocks.{i}.norm2.weight")
|
||||
current_states[f"model.layers.{i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
||||
|
||||
config = InternLMConfig(
|
||||
hidden_size=model_config["hidden_size"],
|
||||
intermediate_size=compute_intermediate_size(model_config["hidden_size"]),
|
||||
num_attention_heads=model_config["num_attention_heads"],
|
||||
num_hidden_layers=model_config["num_layers"],
|
||||
rms_norm_eps=1e-06,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
if model_config["vocab_size"] != -1:
|
||||
config.vocab_size = model_config["vocab_size"]
|
||||
|
||||
config.save_pretrained(folder)
|
||||
torch.save(current_states, os.path.join(folder, "pytorch_model.bin"))
|
||||
|
||||
model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||
del model.config._name_or_path
|
||||
|
||||
finally:
|
||||
shutil.rmtree(folder)
|
||||
|
||||
return config, model
|
||||
|
||||
|
||||
def compute_intermediate_size(n):
|
||||
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
|
||||
|
||||
|
||||
def merge_pp(states_tp_pp):
|
||||
max_tp = len(states_tp_pp)
|
||||
max_pp = len(states_tp_pp[0])
|
||||
|
||||
full_states = []
|
||||
for tp in range(max_tp):
|
||||
layer_shift = 0
|
||||
|
||||
tp_states = {}
|
||||
for pp in range(max_pp):
|
||||
_layer_shift = 0
|
||||
states = states_tp_pp[tp][pp]
|
||||
keys = list(states.keys())
|
||||
for key in keys:
|
||||
match = re.search("\.\d+\.", key)
|
||||
if match is not None:
|
||||
s, e = match.span()
|
||||
layer_idx = int(key[s + 1 : e - 1]) + layer_shift
|
||||
_layer_shift = max(_layer_shift, int(key[s + 1 : e - 1]))
|
||||
name = key[:s] + f".{layer_idx}." + key[e:]
|
||||
tp_states[name] = states[key]
|
||||
else:
|
||||
tp_states[key] = states[key]
|
||||
layer_shift += _layer_shift + 1
|
||||
full_states.append({(key[6:] if key.startswith("model.") else key): value for key, value in tp_states.items()})
|
||||
return full_states
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--src_folder', type=str, default='~/test/') # 需要转换为hf格式的checkpoint文件夹
|
||||
parser.add_argument('--tgt_folder', type=str, default='~/output/') # 存放转换后checkpoint的目标文件夹
|
||||
parser.add_argument('--tokenizer', type=str, default='~/test/tokenizer.model') # Tokenizer 文件的路径
|
||||
args = parser.parse_args()
|
||||
|
||||
def load(fp):
|
||||
with open(fp, "rb") as f:
|
||||
pt_data = torch.load(f, map_location="cpu")
|
||||
return pt_data
|
||||
|
||||
folder = args.src_folder
|
||||
target_folder = args.tgt_folder
|
||||
model_config = load(os.path.join(folder, "model_config.pt"))
|
||||
|
||||
fns = list(os.listdir(folder))
|
||||
|
||||
model_fns = []
|
||||
for fn in fns:
|
||||
if fn.startswith("model_t") and not fn.endswith("md5"):
|
||||
model_fns.append(fn)
|
||||
|
||||
max_tp, max_pp = -1, -1
|
||||
for fn in model_fns:
|
||||
_, tp, pp = os.path.splitext(fn)[0].split("_")
|
||||
max_pp = max(max_pp, int(pp[2:]) + 1)
|
||||
max_tp = max(max_tp, int(tp[2:]) + 1)
|
||||
|
||||
states_tp_pps = [[]]
|
||||
|
||||
for pp in range(max_pp):
|
||||
model_name = f"model_tp0_pp{pp}.pt"
|
||||
states = load(os.path.join(folder, model_name))
|
||||
states_tp_pps[0].append(states)
|
||||
|
||||
config, model = convert2hf(model_config, states_tp_pps)
|
||||
|
||||
os.makedirs(target_folder, exist_ok=True)
|
||||
model.save_pretrained(target_folder, max_shard_size="20GB")
|
||||
|
||||
tokenizer = InternLMTokenizer(args.tokenizer)
|
||||
tokenizer.save_pretrained(target_folder)
|
|
@ -0,0 +1,194 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from termcolor import colored
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(current_dir, "V7.model")
|
||||
tokenizer = SentencePieceProcessor(model_file=model_path)
|
||||
|
||||
|
||||
def write_bin(context: str, path: str) -> None:
|
||||
"""
|
||||
Write bin file.
|
||||
|
||||
Args:
|
||||
context (str): the context of raw file.
|
||||
path (str): the path for output bin file.
|
||||
|
||||
Example:
|
||||
>>> write_bin("今天天气晴朗适合出门散步", "out.bin") # the output file format is 'txt'
|
||||
>>> out.bin
|
||||
>>> {"tokens": [67577, 69095, 63010, 61770, 67783, 69301, 74732]}
|
||||
"""
|
||||
# encode the context into tokens, which is a list, eg. [67577, 69095, 63010, 61770, 67783, 69301, 74732]
|
||||
tokens = tokenizer.encode(context)
|
||||
# transfer the list into dic, key is str 'tokens', value is tokens.
|
||||
# eg. {"tokens": [67577, 69095, 63010, 61770, 67783, 69301, 74732]}
|
||||
data = dict(tokens=tokens)
|
||||
# encode the data into bytes to save
|
||||
saved_bin = str.encode(json.dumps(data) + "\n")
|
||||
|
||||
# write bytes into bin path
|
||||
with open(path, "ab") as f:
|
||||
f.write(saved_bin)
|
||||
|
||||
|
||||
def prepare_meta(bin_file_path: str):
|
||||
"""
|
||||
Prepare metadata for the given bin file.
|
||||
|
||||
Args:
|
||||
bin_file_path (str): the bin file path.
|
||||
"""
|
||||
meta = []
|
||||
cur = 0
|
||||
with open(bin_file_path, "rb") as f:
|
||||
while True:
|
||||
# read lines
|
||||
line = f.readline()
|
||||
# if line is empty, then break
|
||||
if line == b"":
|
||||
break
|
||||
# obtain the token amount of each line
|
||||
length = len(json.loads(line)["tokens"])
|
||||
# meta is a list of tuple(cur, length)
|
||||
# cur: the start index of each line
|
||||
# length: the token amount of each line
|
||||
meta.append((cur, length))
|
||||
# update the cur to generate the meta information of next line
|
||||
cur += len(line)
|
||||
print(meta)
|
||||
# define path of the generated meta file
|
||||
meta_fp = bin_file_path + ".meta"
|
||||
# save the generated meta information
|
||||
with open(meta_fp, "wb") as f:
|
||||
meta = np.array(meta, dtype=np.int32)
|
||||
np.save(f, meta)
|
||||
|
||||
|
||||
def txt2bin(txt_file_path: str, bin_file_path: str):
|
||||
"""
|
||||
Read content from txt file and write to bin file
|
||||
|
||||
Args:
|
||||
txt_file_path (str): txt file path.
|
||||
bin_file_path (str): output bin file path.
|
||||
"""
|
||||
# Check if the txt file exists
|
||||
if not os.path.isfile(txt_file_path):
|
||||
warnings.warn(colored(f"{txt_file_path} does not exist.", "red"))
|
||||
return
|
||||
|
||||
try:
|
||||
# Open the text file
|
||||
with open(txt_file_path, "r") as txt_file:
|
||||
for line in txt_file:
|
||||
# Strip any leading/trailing whitespace
|
||||
stripped_line = line.strip()
|
||||
if stripped_line:
|
||||
# Pass each line to the write_bin function
|
||||
write_bin(stripped_line, bin_file_path)
|
||||
|
||||
print(colored(f"Successfully converted {txt_file_path} to {bin_file_path}", "green"))
|
||||
|
||||
except Exception as e:
|
||||
print(colored(f"Error while converting {txt_file_path} to {bin_file_path}: {str(e)}", "red"))
|
||||
|
||||
|
||||
def json2bin(json_file_path: str, bin_file_path: str):
|
||||
"""
|
||||
Read content from json file and write to bin file
|
||||
|
||||
Args:
|
||||
json_file_path (str): json file path.
|
||||
bin_file_path (str): output bin file path.
|
||||
"""
|
||||
|
||||
if not os.path.isfile(json_file_path):
|
||||
warnings.warn(colored(f"{json_file_path} does not exist.", "red"))
|
||||
return
|
||||
|
||||
try:
|
||||
# load json file
|
||||
with open(json_file_path, "r") as json_file:
|
||||
data = json.load(json_file)
|
||||
# assuming data is a list of dictionaries
|
||||
for record in data:
|
||||
# the type of record is dict, transfer the dict into str
|
||||
context = json.dumps(record)
|
||||
# encode the str and write into bin
|
||||
write_bin(context, bin_file_path)
|
||||
|
||||
print(colored(f"Successfully converted {json_file_path} to {bin_file_path}", "green"))
|
||||
|
||||
except Exception as e:
|
||||
print(colored(f"Error while converting {json_file_path} to {bin_file_path}: {str(e)}", "red"))
|
||||
|
||||
|
||||
def jsonl2bin(jsonl_file_path: str, bin_file_path: str):
|
||||
"""
|
||||
Read content from jsonl file and write to bin file
|
||||
|
||||
Args:
|
||||
jsonl_file_path: jsonl file path.
|
||||
bin_file_path: bin file path.
|
||||
"""
|
||||
|
||||
if not os.path.isfile(jsonl_file_path):
|
||||
warnings.warn(colored(f"{jsonl_file_path} does not exist.", "red"))
|
||||
return
|
||||
|
||||
try:
|
||||
with open(jsonl_file_path, "r") as jsonl_file:
|
||||
for line in jsonl_file:
|
||||
# encode the str and write into bin
|
||||
write_bin(line, bin_file_path)
|
||||
|
||||
print(colored(f"Successfully converted {jsonl_file_path} to {bin_file_path}", "green"))
|
||||
|
||||
except Exception as e:
|
||||
print(colored(f"Error while converting {jsonl_file_path} to {bin_file_path}: {str(e)}", "red"))
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--raw_data_name", required=True, help="Input file name")
|
||||
parser.add_argument(
|
||||
"--input_file_type",
|
||||
choices=["txt", "json", "jsonl"],
|
||||
required=True,
|
||||
help="Input file format (either txt, json or jsonl)",
|
||||
)
|
||||
parser.add_argument("--bin", required=True, help="Path to the output bin file")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
# parse arguments
|
||||
args = parse_args()
|
||||
|
||||
# obtain the raw data path
|
||||
input_file_path = f"{args.raw_data_name}.{args.input_file_type}"
|
||||
|
||||
# different methods for different raw data type, we only support "txt", "json" and "jsonl" data type.
|
||||
if args.input_file_type == "txt":
|
||||
txt2bin(input_file_path, args.bin)
|
||||
elif args.input_file_type == "json":
|
||||
json2bin(input_file_path, args.bin)
|
||||
elif args.input_file_type == "jsonl":
|
||||
jsonl2bin(input_file_path, args.bin)
|
||||
else:
|
||||
print(colored("Invalid input file type. Use --help for more information.", "red"))
|
||||
|
||||
# To avoid potential read/write errors, the metadata preparation follows after creating the .bin file.
|
||||
prepare_meta(args.bin)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,22 @@
|
|||
# InternLM Transformers
|
||||
|
||||
该文件夹下包含了 transformers 格式的 `InternLM` 模型。
|
||||
|
||||
## 权重转换
|
||||
|
||||
`../tools/convert2hf.py` 可以将训练保存的权重一键转换为 transformers 格式。
|
||||
|
||||
```bash
|
||||
python convert2hf.py --src_folder origin_ckpt/ --tgt_folder hf_ckpt/ --tokenizer tokenizes/tokenizer.model
|
||||
```
|
||||
|
||||
然后可以使用 `from_pretrained` 接口加载:
|
||||
|
||||
```python
|
||||
from modeling_internlm import InternLMForCausalLM
|
||||
|
||||
model = InternForCausalLM.from_pretrained("hf_ckpt/")
|
||||
```
|
||||
|
||||
|
||||
`moss_example.py` 展示了如何使用 LoRA 来在 `fnlp/moss-moon-002-sft` 数据集上进行微调的样例。
|
|
@ -0,0 +1,120 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" InternLM model configuration"""
|
||||
|
||||
from transformers.utils import logging
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
INTERNLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||
|
||||
|
||||
class InternLMConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate an InternLM
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the InternLM-7B.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the InternLM model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`InternLMModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import InternLMModel, InternLMConfig
|
||||
|
||||
>>> # Initializing a InternLM internlm-7b style configuration
|
||||
>>> configuration = InternLMConfig()
|
||||
|
||||
>>> # Initializing a model from the internlm-7b style configuration
|
||||
>>> model = InternLMModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
model_type = "internlm"
|
||||
_auto_class = "AutoConfig"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=103168,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.bias = bias
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
|
@ -0,0 +1,69 @@
|
|||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from peft import get_peft_model, LoraConfig, TaskType
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from moss_002_sft import get_dataset, collate_fn
|
||||
|
||||
model_path = "model_path"
|
||||
data_dir = "moss_002_sft"
|
||||
data_num = -1
|
||||
test_size = 10
|
||||
train_batch_size = 1
|
||||
epochs = 5
|
||||
val_per_steps = 1000
|
||||
lr = 9e-6
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM, r=32, lora_alpha=32, lora_dropout=0.1,
|
||||
target_modules=["gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
)
|
||||
|
||||
|
||||
# model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.cuda()
|
||||
|
||||
# dataset
|
||||
train_dataset, val_dataset = get_dataset(tokenizer, data_dir, num=data_num, test_size=test_size)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, tokenizer))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, 1000, epochs * len(train_dataloader)
|
||||
)
|
||||
|
||||
# train
|
||||
fp = open("output", "w")
|
||||
model.train()
|
||||
for epoch in tqdm(range(epochs), desc="Traning Epoch"):
|
||||
batch_bar = tqdm(train_dataloader, desc="Training Batch")
|
||||
for step, batch in enumerate(batch_bar):
|
||||
batch = {k:v.cuda() for k, v in batch.items()}
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
output = model(**batch)
|
||||
|
||||
loss = output.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
batch_bar.set_postfix({"loss": loss.item()})
|
||||
if (step + 1) % val_per_steps == 0:
|
||||
fp.write(f"Epoch {epoch} Batch {step}: Loss={loss.item()}\n")
|
||||
for i in tqdm(range(len(val_dataset)), desc="Generating"):
|
||||
data, label = val_dataset[i]
|
||||
prefix = tokenizer.decode(data.tolist(), skip_special_tokens=True)
|
||||
try:
|
||||
generate = model.generate(input_ids=data.unsqueeze(0).cuda(), temperature=0.7, top_k=50, do_sample=True, repetition_penalty=1.02, max_new_tokens=100, top_p=0.9)
|
||||
text = tokenizer.decode(generate[0].tolist(), skip_special_tokens=True)
|
||||
text = text.replace(prefix, "")
|
||||
fp.write(f"Prefix: {prefix}\nGenerated: {text}" + "\n---------------------------------\n")
|
||||
except Exception as e:
|
||||
fp.write(f"Prefix: {prefix}\nError: {e}" + "\n---------------------------------\n")
|
||||
fp.write("\n==============================\n")
|
||||
model.train()
|
||||
torch.cuda.empty_cache()
|
|
@ -0,0 +1,105 @@
|
|||
import os
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, Dataset as HFDataset
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
# https://github.com/OpenLMLab/MOSS/blob/main/finetune_moss.py
|
||||
def __init__(self, dataset):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = copy.deepcopy(self.dataset[index]["input_ids"])
|
||||
no_loss_spans = copy.deepcopy(self.dataset[index]["no_loss_spans"])
|
||||
|
||||
data = torch.tensor(data, dtype=torch.long)
|
||||
label = copy.deepcopy(data)
|
||||
|
||||
for no_loss_span in no_loss_spans:
|
||||
label[no_loss_span[0] : no_loss_span[1]] = -100
|
||||
|
||||
return data, label
|
||||
|
||||
def collate_fn(batch, tokenizer):
|
||||
batch_input_ids, batch_labels = [], []
|
||||
for input_ids, label in batch:
|
||||
batch_input_ids.append(input_ids)
|
||||
batch_labels.append(label)
|
||||
|
||||
batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.eos_token_id)
|
||||
batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100)
|
||||
|
||||
return {
|
||||
"input_ids": batch_input_ids,
|
||||
"attention_mask": (batch_input_ids == tokenizer.eos_token_id).long(),
|
||||
"labels": batch_labels
|
||||
}
|
||||
|
||||
def process(sample, tokenizer, max_len):
|
||||
chat = sample["plain_text"].split("<eoa>")[:-1]
|
||||
num_turns = sample["num_turns"]
|
||||
meta_instruction = sample["prefix"]
|
||||
|
||||
# encode instruction
|
||||
instruction_ids = tokenizer.encode(meta_instruction)
|
||||
assert isinstance(instruction_ids, list), instruction_ids
|
||||
assert len(instruction_ids) > 0, len(instruction_ids)
|
||||
input_ids = copy.deepcopy(instruction_ids)
|
||||
# We do not calculate loss for instruction.
|
||||
no_loss_spans = [(0, len(instruction_ids))]
|
||||
|
||||
for i in range(num_turns):
|
||||
# Collect dialogues
|
||||
cur_turn_ids = []
|
||||
cur_no_loss_spans = []
|
||||
# Add to cur_turn_ids
|
||||
cur_turn_ids.extend(tokenizer.encode(chat[i] + "<eoa>"))
|
||||
# if key == 'Tool Responses':
|
||||
# # The format tokens (<|Results|>:...<eor>\n) should have losses.
|
||||
# cur_no_loss_spans.append((len(input_ids + cur_turn_ids) + 5, len(input_ids + cur_turn_ids + cur_ids) - 2))
|
||||
if len(input_ids + cur_turn_ids) > max_len:
|
||||
# Too long, break
|
||||
break
|
||||
# Extend input_ids
|
||||
input_ids.extend(cur_turn_ids)
|
||||
no_loss_spans.extend(cur_no_loss_spans)
|
||||
|
||||
if len(input_ids) == len(instruction_ids):
|
||||
# No dialogue, return
|
||||
return {"input_ids": [], "no_loss_spans": []}
|
||||
else:
|
||||
return {"input_ids": input_ids, "no_loss_spans": no_loss_spans}
|
||||
|
||||
|
||||
def load_data(save_dir, tokenizer, max_len, num=-1) -> HFDataset:
|
||||
if os.path.exists(save_dir):
|
||||
print(f"Loading moss-002-sft from {save_dir}")
|
||||
else:
|
||||
print(f"Loading moss-002-sft from datasets")
|
||||
moss_sft = load_dataset("fnlp/moss-002-sft-data", split="train")
|
||||
moss_sft = moss_sft.map(lambda x:process(x, tokenizer, max_len), num_proc=10)
|
||||
moss_sft = moss_sft.filter(lambda x:len(x["input_ids"]) != 0)
|
||||
moss_sft.save_to_disk(save_dir)
|
||||
|
||||
moss_sft = HFDataset.load_from_disk(save_dir)
|
||||
if num != -1:
|
||||
moss_sft = moss_sft.select(range(num))
|
||||
print(
|
||||
f"Load successfully, total {len(moss_sft)} samples.")
|
||||
|
||||
return moss_sft
|
||||
|
||||
def get_dataset(tokenizer, save_dir, max_len=1024, num=-1, test_size=0.1):
|
||||
moss_sft_data = load_data(save_dir, tokenizer, max_len, num)
|
||||
moss_sft_split = moss_sft_data.train_test_split(test_size=test_size)
|
||||
train_dataset = SFTDataset(moss_sft_split["train"])
|
||||
val_dataset = SFTDataset(moss_sft_split["test"])
|
||||
|
||||
return train_dataset, val_dataset
|
||||
|
|
@ -0,0 +1,962 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch InternLM model."""
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.generation.streamers import BaseStreamer
|
||||
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from .configuration_internlm import InternLMConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "InternLMConfig"
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
class InternLMRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
InternLMRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class InternLMRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
||||
return (
|
||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class InternLMMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.act_fn = ACT2FN[hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class InternLMAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: InternLMConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
||||
self.rotary_emb = InternLMRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class InternLMDecoderLayer(nn.Module):
|
||||
def __init__(self, config: InternLMConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = InternLMAttention(config=config)
|
||||
self.mlp = InternLMMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
INTERNLM_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`InternLMConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare InternLM Model outputting raw hidden-states without any specific head on top.",
|
||||
INTERNLM_START_DOCSTRING,
|
||||
)
|
||||
class InternLMPreTrainedModel(PreTrainedModel):
|
||||
config_class = InternLMConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["InternLMDecoderLayer"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, InternLMModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
INTERNLM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare InternLM Model outputting raw hidden-states without any specific head on top.",
|
||||
INTERNLM_START_DOCSTRING,
|
||||
)
|
||||
class InternLMModel(InternLMPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: InternLMConfig
|
||||
"""
|
||||
_auto_class = "AutoModel"
|
||||
|
||||
def __init__(self, config: InternLMConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
@add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class InternLMForCausalLM(InternLMPreTrainedModel):
|
||||
_auto_class = "AutoModelForCausalLM"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = InternLMModel(config)
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, InternLMForCausalLM
|
||||
|
||||
>>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||
):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
|
||||
prompt = ""
|
||||
for record in history:
|
||||
prompt += f"""<s><|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
|
||||
if len(prompt) == 0:
|
||||
prompt += "<s>"
|
||||
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
|
||||
return tokenizer([prompt], return_tensors="pt")
|
||||
|
||||
@torch.no_grad()
|
||||
def chat(self,
|
||||
tokenizer,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = [],
|
||||
streamer: Optional[BaseStreamer] = None,
|
||||
max_new_tokens: int = 1024,
|
||||
do_sample: bool = True,
|
||||
temperature: float = 0.8,
|
||||
top_p: float = 0.8,
|
||||
**kwargs):
|
||||
inputs = self.build_inputs(tokenizer, query, history)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
||||
outputs = self.generate(**inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
**kwargs)
|
||||
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
response = response.split("<eoa>")[0]
|
||||
history = history + [(query, response)]
|
||||
return response, history
|
||||
|
||||
@torch.no_grad()
|
||||
def stream_chat(self,
|
||||
tokenizer,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = [],
|
||||
max_new_tokens: int = 1024,
|
||||
do_sample: bool = True,
|
||||
temperature: float = 0.8,
|
||||
top_p: float = 0.8,
|
||||
**kwargs):
|
||||
class ChatStreamer(BaseStreamer):
|
||||
def __init__(self, tokenizer) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def put(self, value):
|
||||
if len(value.shape) > 1 and value.shape[0] > 1:
|
||||
raise ValueError("ChatStreamer only supports batch size 1")
|
||||
elif len(value.shape) > 1:
|
||||
value = value[0]
|
||||
token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
|
||||
if token.strip() != "<eoa>":
|
||||
print(token, end="")
|
||||
|
||||
def end(self):
|
||||
print("")
|
||||
|
||||
return self.chat(
|
||||
tokenizer=tokenizer,
|
||||
query=query,
|
||||
streamer=ChatStreamer(tokenizer=tokenizer),
|
||||
history=history,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The InternLM Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
[`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||
(e.g. GPT-2) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
""",
|
||||
INTERNLM_START_DOCSTRING,
|
||||
)
|
||||
class InternLMForSequenceClassification(InternLMPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = InternLMModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
|
@ -0,0 +1,242 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tokenization classes for IntermLM."""
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {}
|
||||
|
||||
|
||||
class InternLMTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a InternLM tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
_auto_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
pad_token="</s>",
|
||||
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
add_bos_token=True,
|
||||
add_eos_token=False,
|
||||
decode_with_prefix_space=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_file = vocab_file
|
||||
self.add_bos_token = add_bos_token
|
||||
self.add_eos_token = add_eos_token
|
||||
self.decode_with_prefix_space = decode_with_prefix_space
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(vocab_file)
|
||||
self._no_prefix_space_tokens = None
|
||||
|
||||
""" Initialisation"""
|
||||
|
||||
@property
|
||||
def no_prefix_space_tokens(self):
|
||||
if self._no_prefix_space_tokens is None:
|
||||
vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
|
||||
self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
|
||||
return self._no_prefix_space_tokens
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Returns vocab size"""
|
||||
return self.sp_model.get_piece_size()
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> Optional[int]:
|
||||
return self.sp_model.bos_id()
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> Optional[int]:
|
||||
return self.sp_model.eos_id()
|
||||
|
||||
def get_vocab(self):
|
||||
"""Returns vocab as a dict"""
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text):
|
||||
"""Returns a tokenized string."""
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.sp_model.piece_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
token = self.sp_model.IdToPiece(index)
|
||||
return token
|
||||
|
||||
def _maybe_add_prefix_space(self, tokens, decoded):
|
||||
if tokens and tokens[0] not in self.no_prefix_space_tokens:
|
||||
return " " + decoded
|
||||
else:
|
||||
return decoded
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
prev_is_special = False
|
||||
for token in tokens:
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if not prev_is_special:
|
||||
out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
prev_is_special = True
|
||||
current_sub_tokens = []
|
||||
else:
|
||||
current_sub_tokens.append(token)
|
||||
prev_is_special = False
|
||||
out_string += self.sp_model.decode(current_sub_tokens)
|
||||
out_string = self.clean_up_tokenization(out_string)
|
||||
out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
|
||||
return out_string[1:]
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
elif not os.path.isfile(self.vocab_file):
|
||||
with open(out_vocab_file, "wb") as fi:
|
||||
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||
fi.write(content_spiece_model)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
if self.add_bos_token:
|
||||
bos_token_ids = [self.bos_token_id]
|
||||
else:
|
||||
bos_token_ids = []
|
||||
|
||||
output = bos_token_ids + token_ids_0
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output = output + token_ids_1
|
||||
|
||||
if self.add_eos_token:
|
||||
output = output + [self.eos_token_id]
|
||||
|
||||
return output
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer `prepare_for_model` method.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||
)
|
||||
|
||||
if token_ids_1 is None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
||||
use of token type ids, therefore a list of zeros is returned.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of zeros.
|
||||
"""
|
||||
eos = [self.eos_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(token_ids_0 + eos) * [0]
|
||||
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
|
@ -0,0 +1,509 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import socket
|
||||
import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import internlm
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.data.batch_sampler import StaticBatchSampler
|
||||
from internlm.data.collaters import packed_collate_fn
|
||||
from internlm.data.dummy_dataset import RandomDataset
|
||||
from internlm.data.packed_dataset import (
|
||||
PackedDataset,
|
||||
PackedDatasetWithoutCuSeqlen,
|
||||
get_packed_dataset_without_short_length,
|
||||
)
|
||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
get_master_node,
|
||||
get_megatron_flops,
|
||||
get_process_rank,
|
||||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import (
|
||||
load_context,
|
||||
load_model_checkpoint,
|
||||
load_optimizer_checkpoint,
|
||||
load_sampler,
|
||||
load_scheduler,
|
||||
save_checkpoint,
|
||||
)
|
||||
from internlm.utils.parallel import (
|
||||
is_no_pp_or_last_stage,
|
||||
sync_model_param,
|
||||
sync_model_param_within_tp,
|
||||
)
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
|
||||
# global llm logger
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
|
||||
"""
|
||||
Initialize distributed environment for distributed training.
|
||||
|
||||
Args:
|
||||
config (str): Config file path.
|
||||
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
|
||||
master_port (str): The master port for distributed training. 8888 by default.
|
||||
seed (int, optional): Specified random seed for every process. 1024 by default.
|
||||
"""
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if launcher == "torch":
|
||||
internlm.launch_from_torch(config=config, seed=seed)
|
||||
elif launcher == "slurm":
|
||||
internlm.launch_from_slurm(
|
||||
config=config,
|
||||
host=get_master_node(),
|
||||
port=master_port,
|
||||
seed=seed,
|
||||
)
|
||||
else:
|
||||
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
|
||||
|
||||
|
||||
def initialize_model():
|
||||
"""
|
||||
Initialize model.
|
||||
|
||||
Returns: The neural network model to be trained or evaluated.
|
||||
"""
|
||||
|
||||
assert (
|
||||
not hasattr(gpc.config.parallel, "pipeline") or gpc.config.parallel.pipeline == 1
|
||||
), "Pipeline parallelism is not supported for now."
|
||||
|
||||
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=is_no_pp_or_last_stage(),
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
|
||||
# This sync is very important, cause the model weights kept in optimizer are copied
|
||||
# from the origin parameters in the memory, so we should make sure the dp sync
|
||||
# does not influence the model weights in optimizer be different with the origin parameters.
|
||||
sync_model_param(model, parallel_mode=ParallelMode.DATA)
|
||||
|
||||
# This function is needed to make sure parameters that are not splitted by tensor parallelism are
|
||||
# the same across tensor parallelism.
|
||||
sync_model_param_within_tp(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_train_data_loader(num_worker: int = 0):
|
||||
"""
|
||||
Generate and return the training data loader.
|
||||
|
||||
Returns: A tuple of (train_dl, dataset_types).
|
||||
"""
|
||||
|
||||
# Get the dataset types
|
||||
dataset_types = None
|
||||
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
# Get the sample weight dictionary
|
||||
train_folder = data_cfg.train_folder
|
||||
|
||||
if not train_folder:
|
||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||
if data_cfg.pack_sample_into_one:
|
||||
train_ds = PackedDatasetWithoutCuSeqlen(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = PackedDataset(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = get_packed_dataset_without_short_length(
|
||||
folder=data_cfg.train_folder,
|
||||
packed_length=data_cfg.packed_length,
|
||||
max_length_per_sample=data_cfg.seq_len,
|
||||
show_progress=dist.get_rank() == 0,
|
||||
min_length=data_cfg.min_length,
|
||||
min_length_dict=data_cfg.get("min_length_dict", {}),
|
||||
pack_into_one_sample=data_cfg.pack_sample_into_one,
|
||||
)
|
||||
|
||||
# partition already completed
|
||||
# assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen))
|
||||
if isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)):
|
||||
datasets = [train_ds]
|
||||
else:
|
||||
datasets = train_ds.datasets
|
||||
|
||||
# Create the training dataset sampler
|
||||
train_sampler = StaticBatchSampler(
|
||||
datasets,
|
||||
batch_size=data_cfg.micro_num,
|
||||
rampup_batch_size=data_cfg.rampup_batch_size,
|
||||
micro_bsz=data_cfg.micro_bsz,
|
||||
seed=1024,
|
||||
drop_last=True,
|
||||
data_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
|
||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||
|
||||
# Create the training data loader
|
||||
train_dl = DataLoader(
|
||||
dataset=train_ds,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=num_worker,
|
||||
pin_memory=True,
|
||||
collate_fn=train_collate_fn,
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
return train_dl, dataset_types
|
||||
|
||||
|
||||
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
|
||||
"""
|
||||
Load and return the new batch data based on training data loader.
|
||||
|
||||
Args:
|
||||
train_dl (torch.utils.data.DataLoader): Dataloader for training.
|
||||
train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
train_state (TrainState): Current training state.
|
||||
|
||||
Returns: A batch data and the updated train_iter.
|
||||
"""
|
||||
|
||||
timer("batch-gen").start()
|
||||
try:
|
||||
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
|
||||
next(train_state.batch_sampler_iter)
|
||||
except StopIteration:
|
||||
train_iter = iter(train_dl)
|
||||
batch = next(train_iter)
|
||||
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
|
||||
next(train_state.batch_sampler_iter)
|
||||
train_state.num_consumed_samples_in_epoch = 0
|
||||
timer("batch-gen").stop()
|
||||
|
||||
batch[0].pop("type_ids", None)
|
||||
|
||||
return batch, train_iter
|
||||
|
||||
|
||||
def initialize_optimizer(model: nn.Module):
|
||||
"""
|
||||
Initialize optimizer.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Your model instance to be trained or evaluated.
|
||||
|
||||
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
|
||||
"""
|
||||
adam_cfg = gpc.config.adam
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
|
||||
lr=adam_cfg.lr,
|
||||
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
|
||||
eps=adam_cfg.adam_eps,
|
||||
)
|
||||
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
|
||||
)
|
||||
|
||||
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
|
||||
|
||||
lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
|
||||
|
||||
return optimizer, beta2_scheduler, lr_scheduler
|
||||
|
||||
|
||||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
logger,
|
||||
success_update,
|
||||
batch_count,
|
||||
batch,
|
||||
train_state,
|
||||
optimizer,
|
||||
beta2_scheduler,
|
||||
trainer,
|
||||
start_time,
|
||||
loss,
|
||||
grad_norm,
|
||||
):
|
||||
"""
|
||||
Print some training metrics of current batch.
|
||||
"""
|
||||
|
||||
if success_update in (0, True):
|
||||
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
||||
|
||||
if success_update and gpc.is_rank_for_log():
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
if hasattr(trainer.engine.optimizer, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
|
||||
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
|
||||
|
||||
num_tokens_in_batch = batch[1].nelement()
|
||||
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
|
||||
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
|
||||
tk_per_gpu = 0
|
||||
tk_per_gpu = round(
|
||||
num_tokens_in_batch
|
||||
* gpc.get_world_size(ParallelMode.DATA)
|
||||
/ gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
/ (time.time() - start_time),
|
||||
2,
|
||||
)
|
||||
|
||||
tflops = get_tflops_func((time.time() - start_time))
|
||||
|
||||
infos = {
|
||||
"tflops": tflops,
|
||||
"step": batch_count,
|
||||
"loss": loss.item(),
|
||||
"tgs (tokens/gpu/second)": tk_per_gpu,
|
||||
"lr": lr,
|
||||
"loss_scale": scaler,
|
||||
"grad_norm": grad_norm,
|
||||
}
|
||||
|
||||
infos["micro_num"] = len(batch[1])
|
||||
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
|
||||
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
|
||||
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
|
||||
infos["largest_length"] = max_length_in_batch # the longest input
|
||||
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
|
||||
infos["smallest_batch"] = min_samples_in_batch
|
||||
infos["adam_beta2"] = beta2_scheduler.get_beta2()
|
||||
|
||||
line = ""
|
||||
for k, v in infos.items():
|
||||
line += f"{k}={v},"
|
||||
|
||||
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
|
||||
line += f"fwd_bwd_time={fwd_bwd_time}"
|
||||
|
||||
logger.info(line)
|
||||
|
||||
|
||||
def main(args):
|
||||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
|
||||
# init setting
|
||||
skip_batches = gpc.config.data.skip_batches
|
||||
total_steps = gpc.config.data.total_steps
|
||||
load_optimizer = gpc.config.ckpt.load_optimizer
|
||||
label_smoothing = gpc.config.loss.label_smoothing
|
||||
lr = gpc.config.adam.lr
|
||||
|
||||
# ckpt setting
|
||||
save_ckpt_folder = gpc.config.ckpt.save_ckpt_folder
|
||||
enable_save_ckpt = gpc.config.ckpt.enable_ckpt
|
||||
checkpoint_every = gpc.config.ckpt.checkpoint_every
|
||||
|
||||
load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
|
||||
load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
|
||||
|
||||
get_tflops_func = partial(
|
||||
get_megatron_flops,
|
||||
checkpoint=gpc.config.model.checkpoint,
|
||||
seq_len=gpc.config.SEQ_LEN,
|
||||
hidden_size=gpc.config.model.hidden_size,
|
||||
num_layers=gpc.config.model.num_layers,
|
||||
vocab_size=gpc.config.model.vocab_size,
|
||||
global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
|
||||
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
)
|
||||
|
||||
# get and broadcast current time
|
||||
current_time = launch_time()
|
||||
objs = [current_time]
|
||||
dist.broadcast_object_list(objs, src=0)
|
||||
current_time = objs[0]
|
||||
|
||||
model_load_path = None
|
||||
if load_resume_ckpt_folder is not None:
|
||||
logger.info(
|
||||
f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = load_resume_ckpt_folder
|
||||
elif load_model_only_folder is not None:
|
||||
logger.info(
|
||||
f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = load_model_only_folder
|
||||
else:
|
||||
logger.info(
|
||||
f"===========New Run {current_time} on host:{socket.gethostname()},"
|
||||
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
|
||||
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
|
||||
)
|
||||
|
||||
# initialize and resume train state
|
||||
train_state = TrainState(gpc.config)
|
||||
|
||||
# initialize model
|
||||
model = initialize_model()
|
||||
|
||||
# initialize loss function
|
||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
|
||||
|
||||
# initialize the train data loader
|
||||
train_dl, _ = get_train_data_loader(num_worker=4)
|
||||
train_state.init_batch_sampler(train_dl)
|
||||
|
||||
# Loading model weights must be done before zero is initialized.
|
||||
if model_load_path is not None:
|
||||
load_model_checkpoint(folder=model_load_path, model=model)
|
||||
|
||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
|
||||
|
||||
# Loading other persistent training states.
|
||||
if load_resume_ckpt_folder is not None:
|
||||
# load lr scheduler states.
|
||||
load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
|
||||
# load training states.
|
||||
load_context(load_resume_ckpt_folder, train_dl, train_state)
|
||||
# load dataloader sampler states.
|
||||
load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler)
|
||||
# load optimzier states.
|
||||
if load_optimizer:
|
||||
load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
|
||||
|
||||
# initialize trainer
|
||||
trainer, train_dl, _, _ = internlm.initialize_trainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dl,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
)
|
||||
|
||||
# initialize the batch skipper
|
||||
batch_skipper = BatchSkipper(skip_batches)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# transfer the train data loader into train data iterator
|
||||
train_iter = iter(train_dl)
|
||||
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
if batch_count % 50 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
start_time = time.time()
|
||||
timer("one-batch").start()
|
||||
|
||||
# load batch data
|
||||
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
|
||||
# record the consumed samples in training
|
||||
train_state.batch_count = batch_count
|
||||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||
if batch_skipper(batch_count): # skip this batch
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Skip batch count:`{batch_count}`...")
|
||||
timer("one-batch").stop()
|
||||
continue
|
||||
|
||||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
||||
timer("fwd-bwd").stop()
|
||||
assert loss is not None
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
trainer_result = trainer.step()
|
||||
assert trainer_result is not None
|
||||
|
||||
success_update, grad_norm = trainer_result
|
||||
if success_update: # update parameters successfully
|
||||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
logger=logger,
|
||||
success_update=success_update,
|
||||
batch_count=batch_count,
|
||||
batch=batch,
|
||||
train_state=train_state,
|
||||
optimizer=optimizer,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
grad_norm=grad_norm,
|
||||
)
|
||||
|
||||
timer("one-batch").stop()
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# # save batch sampler that tracks the true consumed samples
|
||||
if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
|
||||
save_checkpoint(
|
||||
folder=save_ckpt_folder,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=lr_scheduler,
|
||||
train_state=train_state,
|
||||
model_config=gpc.config.model,
|
||||
)
|
||||
|
||||
# wait for all checkpoint uploads to be completed
|
||||
dist.barrier()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}")
|
||||
traceback.print_exc()
|
|
@ -0,0 +1 @@
|
|||
0.1.0
|
|
@ -0,0 +1,253 @@
|
|||
"""
|
||||
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers. We mainly modified part of the code logic to adapt to the generation of our model.
|
||||
Please refer to these links below for more information:
|
||||
1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
|
||||
2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
|
||||
3. transformers: https://github.com/huggingface/transformers
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Callable, Optional
|
||||
import copy
|
||||
import warnings
|
||||
import logging
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.utils import logging
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_interactive(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
additional_eos_token_id: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
|
||||
input_length = len(inputs["input_ids"][0])
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.cuda()
|
||||
input_ids = inputs["input_ids"]
|
||||
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
||||
if generation_config is None:
|
||||
generation_config = model.generation_config
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
if additional_eos_token_id is not None:
|
||||
eos_token_id.append(additional_eos_token_id)
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
if has_default_max_length and generation_config.max_new_tokens is None:
|
||||
warnings.warn(
|
||||
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
||||
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
||||
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
||||
UserWarning,
|
||||
)
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logger.warn(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
if input_ids_seq_length >= generation_config.max_length:
|
||||
input_ids_string = "input_ids"
|
||||
logger.warning(
|
||||
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||
" increasing `max_new_tokens`."
|
||||
)
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
|
||||
logits_processor = model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=input_ids,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
stopping_criteria = model._get_stopping_criteria(
|
||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||
)
|
||||
logits_warper = model._get_logits_warper(generation_config)
|
||||
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
scores = None
|
||||
while True:
|
||||
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
# forward pass to get next token
|
||||
outputs = model(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# sample
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
if generation_config.do_sample:
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
else:
|
||||
next_tokens = torch.argmax(probs, dim=-1)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
model_kwargs = model._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=False
|
||||
)
|
||||
unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
|
||||
|
||||
output_token_ids = input_ids[0].cpu().tolist()
|
||||
output_token_ids = output_token_ids[input_length:]
|
||||
for each_eos_token_id in eos_token_id:
|
||||
if output_token_ids[-1] == each_eos_token_id:
|
||||
output_token_ids = output_token_ids[:-1]
|
||||
response = tokenizer.decode(output_token_ids)
|
||||
|
||||
yield response
|
||||
# stop when each sentence is finished, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
break
|
||||
|
||||
|
||||
def on_btn_click():
|
||||
del st.session_state.messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationConfig:
|
||||
max_length: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
do_sample: Optional[bool] = True
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def load_model():
|
||||
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()
|
||||
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def prepare_generation_config():
|
||||
with st.sidebar:
|
||||
max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048)
|
||||
top_p = st.slider(
|
||||
'Top P', 0.0, 1.0, 0.8, step=0.01
|
||||
)
|
||||
temperature = st.slider(
|
||||
'Temperature', 0.0, 1.0, 0.7, step=0.01
|
||||
)
|
||||
st.button("Clear Chat History", on_click=on_btn_click)
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
max_length=max_length,
|
||||
top_p=top_p,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
return generation_config
|
||||
|
||||
|
||||
user_prompt = "<|User|>:{user}<eoh>\n"
|
||||
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
|
||||
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
|
||||
|
||||
|
||||
def combine_history(prompt):
|
||||
messages = st.session_state.messages
|
||||
total_prompt = ""
|
||||
for message in messages:
|
||||
cur_content = message["content"]
|
||||
if message["role"] == "user":
|
||||
cur_prompt = user_prompt.replace("{user}", cur_content)
|
||||
elif message["role"] == "robot":
|
||||
cur_prompt = robot_prompt.replace("{robot}", cur_content)
|
||||
else:
|
||||
raise RuntimeError
|
||||
total_prompt += cur_prompt
|
||||
total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
|
||||
return total_prompt
|
||||
|
||||
|
||||
def main():
|
||||
torch.cuda.empty_cache()
|
||||
print("load model begin.")
|
||||
model, tokenizer = load_model()
|
||||
print("load model end.")
|
||||
|
||||
user_avator = "doc/imgs/user.png"
|
||||
robot_avator = "doc/imgs/robot.png"
|
||||
|
||||
st.title("InternLM-Chat-7B")
|
||||
|
||||
generation_config = prepare_generation_config()
|
||||
|
||||
# Initialize chat history
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"], avatar=message.get("avatar")):
|
||||
st.markdown(message["content"])
|
||||
|
||||
# Accept user input
|
||||
if prompt := st.chat_input("What is up?"):
|
||||
# Display user message in chat message container
|
||||
with st.chat_message("user", avatar=user_avator):
|
||||
st.markdown(prompt)
|
||||
real_prompt = combine_history(prompt)
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})
|
||||
|
||||
print(f"cur real input:\n{real_prompt}\n")
|
||||
|
||||
with st.chat_message("robot", avatar=robot_avator):
|
||||
message_placeholder = st.empty()
|
||||
for cur_response in generate_interactive(model=model, tokenizer=tokenizer, prompt=real_prompt, additional_eos_token_id=103028, **asdict(generation_config)):
|
||||
# Display robot response in chat message container
|
||||
message_placeholder.markdown(cur_response + "▌")
|
||||
message_placeholder.markdown(cur_response)
|
||||
print(f"cur total response:\n{cur_response}\n")
|
||||
# Add robot response to chat history
|
||||
st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|