Merge branch 'main' into fix/format

pull/2738/head
binmakeswell 2023-02-15 20:23:51 +08:00
commit 93b788b95a
147 changed files with 5014 additions and 6684 deletions

View File

@ -96,6 +96,7 @@ jobs:
      - name: Store TensorNVMe Cache
        run: |
          cd TensorNVMe
          cp -p -r ./build /github/home/tensornvme_cache/
      - name: Checkout Colossal-AI

View File

@ -0,0 +1,28 @@
name: Build Documentation upon Release

on:
  workflow_dispatch:
  pull_request:
    paths:
      - 'version.txt'
      - 'docs/**'
    types:
      - closed

jobs:
  build-doc:
    name: Trigger Documentation Build Workflow
    if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
    runs-on: ubuntu-latest
    steps:
      - name: trigger workflow in ColossalAI-Documentation
        run: |
          curl \
            -X POST \
            -H "Accept: application/vnd.github+json" \
            -H "Authorization: Bearer ${GH_TOKEN}" \
            -H "X-GitHub-Api-Version: 2022-11-28" \
            https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \
            -d '{"ref":"main"}'
        env:
          GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}}

View File

@ -0,0 +1,42 @@
name: Run ChatGPT examples

on:
  pull_request:
    types: [synchronize, opened, reopened]
    paths:
      - 'applications/ChatGPT/chatgpt/**'
      - 'applications/ChatGPT/requirements.txt'
      - 'applications/ChatGPT/setup.py'
      - 'applications/ChatGPT/examples/**'

jobs:
  tests:
    name: Run ChatGPT examples
    runs-on: [self-hosted, gpu]
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt
    timeout-minutes: 30
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout ColossalAI
        uses: actions/checkout@v2
      - name: Install ColossalAI and ChatGPT
        run: |
          pip install -v .
          cd applications/ChatGPT
          pip install -v .
          pip install -r examples/requirements.txt
      - name: Execute Examples
        run: |
          cd applications/ChatGPT
          ./examples/test_ci.sh
        env:
          NCCL_SHM_DISABLE: 1
          MAX_JOBS: 8
          PROMPT_PATH: /data/scratch/chatgpt/prompts.csv

View File

@ -0,0 +1,42 @@
name: Run ChatGPT unit tests

on:
  pull_request:
    types: [synchronize, opened, reopened]
    paths:
      - 'applications/ChatGPT/chatgpt/**'
      - 'applications/ChatGPT/requirements.txt'
      - 'applications/ChatGPT/setup.py'
      - 'applications/ChatGPT/requirements-test.txt'
      - 'applications/ChatGPT/tests/**'
      - 'applications/ChatGPT/pytest.ini'

jobs:
  tests:
    name: Run ChatGPT unit tests
    runs-on: [self-hosted, gpu]
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt
    timeout-minutes: 30
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout ColossalAI
        uses: actions/checkout@v2
      - name: Install ColossalAI and ChatGPT
        run: |
          pip install -v .
          cd applications/ChatGPT
          pip install -v .
          pip install -r requirements-test.txt
      - name: Execute Unit Testing
        run: |
          cd applications/ChatGPT
          pytest tests/
        env:
          NCCL_SHM_DISABLE: 1
          MAX_JOBS: 8

View File

@ -292,7 +292,13 @@ def generate_user_engagement_leaderboard_image(github_token: str, output_path: s
y = []
if len(total_engagement_count) > 0:
ranking = []
for name, count in total_engagement_count.items():
ranking.append((name, count))
ranking.sort(key=lambda x: x[1], reverse=True)
for name, count in ranking:
x.append(count)
y.append(name)

View File

@ -3,7 +3,7 @@
[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/)
Colossal-AI: A unified deep learning system for the big model era
Colossal-AI: Making large AI models cheaper, easier to use, and more scalable
<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
<a href="https://www.colossalai.org/"> Documentation </a> |
@ -23,10 +23,10 @@
</div>
## News
* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
@ -64,6 +64,7 @@
<li>
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI 成功案例</a>
<ul>
<li><a href="#ChatGPT">ChatGPT: 低成本复现ChatGPT完整流程</a></li>
<li><a href="#AIGC">AIGC: 加速 Stable Diffusion</a></li>
<li><a href="#生物医药">生物医药: 加速AlphaFold蛋白质结构预测</a></li>
</ul>
@ -102,7 +103,7 @@ Colossal-AI provides a collection of parallel components. Our goal is to make your
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) tensor parallelism
- [Sequence parallelism](https://arxiv.org/abs/2105.13120)
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
- [Auto-parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
- [Auto-parallelism](https://arxiv.org/abs/2302.02599)
- Heterogeneous memory management
- [PatrickStar](https://arxiv.org/abs/2108.05818)
- User-friendly
@ -209,6 +210,29 @@ Colossal-AI provides a collection of parallel components. Our goal is to make your
<p align="right">(<a href="#top">back to top</a>)</p>
## Colossal-AI Success Stories
### ChatGPT
Low-cost reproduction of the complete [ChatGPT](https://openai.com/blog/chatgpt/) pipeline [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatGPT) [[blog]](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
<p id="ChatGPT_scaling" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
</p>
- Up to 7.73x faster single-server training and 1.42x faster single-GPU inference
<p id="ChatGPT-1GPU" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
</p>
- Up to 10.3x growth in model capacity on a single GPU
- A minimal demo training process requires as little as 1.62GB of GPU memory (any consumer-grade GPU)
<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
</p>
- Up to 3.7x growth in fine-tuning model capacity on a single GPU
- While maintaining sufficiently high running speed
<p align="right">(<a href="#top">back to top</a>)</p>
### AIGC
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion)

View File

@ -3,7 +3,7 @@
[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/)
Colossal-AI: A Unified Deep Learning System for Big Model Era
Colossal-AI: Making big AI models cheaper, easier, and scalable
<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
<a href="https://www.colossalai.org/"> Documentation </a> |
@ -24,10 +24,10 @@
</div>
## Latest News
* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
## Table of Contents
@ -64,6 +64,7 @@
<li>
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI for Real World Applications</a>
<ul>
<li><a href="#ChatGPT">ChatGPT: Low-cost ChatGPT Equivalent Implementation Process</a></li>
<li><a href="#AIGC">AIGC: Acceleration of Stable Diffusion</a></li>
<li><a href="#Biomedicine">Biomedicine: Acceleration of AlphaFold Protein Structure</a></li>
</ul>
@ -104,7 +105,7 @@ distributed training and inference in a few lines.
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism
- [Sequence Parallelism](https://arxiv.org/abs/2105.13120)
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
- [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
- [Auto-Parallelism](https://arxiv.org/abs/2302.02599)
- Heterogeneous Memory Management
- [PatrickStar](https://arxiv.org/abs/2108.05818)
@ -211,6 +212,30 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
<p align="right">(<a href="#top">back to top</a>)</p>
## Colossal-AI in the Real World
### ChatGPT
A low-cost reproduction of the complete [ChatGPT](https://openai.com/blog/chatgpt/) pipeline. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatGPT) [[blog]](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
<p id="ChatGPT_scaling" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
</p>
- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference
<p id="ChatGPT-1GPU" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
</p>
- Up to 10.3x growth in model capacity on one GPU
- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)
<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
</p>
- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
- While maintaining sufficiently high running speed
<p align="right">(<a href="#top">back to top</a>)</p>
### AIGC
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion).

applications/ChatGPT/.gitignore
View File

@ -0,0 +1,146 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/.build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# IDE
.idea/
.vscode/
# macos
*.DS_Store
#data/
docs/.build
# pytorch checkpoint
*.pt
# ignore version.py generated by setup.py
colossalai/version.py

View File

@ -0,0 +1,202 @@
Copyright 2021- HPC-AI Technology Inc. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2021- HPC-AI Technology Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,114 @@
# RLHF - Colossal-AI
Implementation of RLHF (Reinforcement Learning from Human Feedback) powered by Colossal-AI. It supports distributed training and offloading, so it can fit extremely large models. More details can be found in the [blog](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt).
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/chatgpt.png" width=700/>
</p>
## Training process (step 3)
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/experience.jpg" width=500/>
</p>
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/train.jpg" width=500/>
</p>
## Install
```shell
pip install .
```
## Usage
The main entry point is `Trainer`. Only the PPO trainer is supported for now. Several training strategies are available:
- NaiveStrategy: the simplest strategy, which trains on a single GPU.
- DDPStrategy: uses `torch.nn.parallel.DistributedDataParallel` to train on multiple GPUs.
- ColossalAIStrategy: uses ColossalAI's Gemini and ZeRO. It eliminates model duplication on each GPU and supports offloading, which is very useful when training large models on multiple GPUs.
Simplest usage:
```python
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy()

with strategy.model_init_context():
    # init your model here
    actor = Actor()
    critic = Critic()

trainer = PPOTrainer(strategy, actor, critic, ...)
trainer.fit(dataset, ...)
```
For more details, see `examples/`.
We also support training the reward model with real-world data. See `examples/train_reward_model.py`.
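A minimal sketch of that workflow, based on the `RewardDataset`, `GPTRM`, and `PairWiseLoss` classes added in this PR; the dataset name, import paths, and training loop below are illustrative assumptions, not the exact contents of `examples/train_reward_model.py`:
```python
import torch
from chatgpt.dataset import RewardDataset    # assumed package path for reward_dataset.py
from chatgpt.nn import GPTRM, PairWiseLoss
from datasets import load_dataset    # any dataset with 'prompt'/'chosen'/'rejected' fields
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# RewardDataset tokenizes (prompt + chosen) and (prompt + rejected) pairs
data = load_dataset('Dahoas/rm-static', split='train')    # illustrative dataset choice
dataset = RewardDataset(data, tokenizer, max_length=512)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

model = GPTRM().cuda()
loss_fn = PairWiseLoss()    # ranking loss: chosen answers should score higher
optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)

for chosen_ids, chosen_mask, reject_ids, reject_mask in dataloader:
    # each item is stored as (1, max_length), so drop the extra dim after batching
    chosen_reward = model(chosen_ids.squeeze(1).cuda(), attention_mask=chosen_mask.squeeze(1).cuda())
    reject_reward = model(reject_ids.squeeze(1).cuda(), attention_mask=reject_mask.squeeze(1).cuda())
    loss = loss_fn(chosen_reward, reject_reward)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```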
## Todo
- [x] implement PPO training
- [x] implement training reward model
- [x] support LoRA
- [ ] implement PPO-ptx fine-tuning
- [ ] integrate with Ray
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL)
## Invitation to open-source contribution
Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing power, datasets, or models are welcome to join and build an ecosystem with Colossal-AI, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!
You may contact us or participate in the following ways:
1. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) or submitting a [PR](https://github.com/hpcaitech/ColossalAI/pulls) on GitHub
2. Join the Colossal-AI community on
[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
and [WeChat](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
3. Check out and fill in the [cooperation proposal](https://www.hpc-ai.tech/partners)
4. Email your proposal to contact@hpcaitech.com
Thanks so much to all of our amazing contributors!
## Quick Preview
<p id="ChatGPT_scaling" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
</p>
- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference
<p id="ChatGPT-1GPU" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
</p>
- Up to 10.3x growth in model capacity on one GPU
- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)
<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
</p>
- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
- While maintaining sufficiently high running speed
## Citations
```bibtex
@article{Hu2021LoRALA,
    title   = {LoRA: Low-Rank Adaptation of Large Language Models},
    author  = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.09685}
}

@article{ouyang2022training,
    title   = {Training language models to follow instructions with human feedback},
    author  = {Ouyang, Long and Wu, Jeff and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll L and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others},
    journal = {arXiv preprint arXiv:2203.02155},
    year    = {2022}
}
```

View File

@ -0,0 +1,94 @@
# Benchmarks
## Benchmark GPT on dummy prompt data
We provide various GPT models (the string in parentheses is the corresponding model name used in this script):
- GPT2-S (s)
- GPT2-M (m)
- GPT2-L (l)
- GPT2-XL (xl)
- GPT2-4B (4b)
- GPT2-6B (6b)
- GPT2-8B (8b)
- GPT2-10B (10b)
- GPT2-12B (12b)
- GPT2-15B (15b)
- GPT2-18B (18b)
- GPT2-20B (20b)
- GPT2-24B (24b)
- GPT2-28B (28b)
- GPT2-32B (32b)
- GPT2-36B (36b)
- GPT2-40B (40b)
- GPT3 (175b)
We also provide various training strategies:
- ddp: torch DDP
- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3
- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload
- colossalai_zero2: ColossalAI zero2
- colossalai_zero2_cpu: ColossalAI zero2-offload
- colossalai_zero1: ColossalAI zero1
- colossalai_zero1_cpu: ColossalAI zero1-offload
Only `torchrun` launching is supported for now, e.g.:
```shell
# run GPT2-S on single-node single-GPU with min batch size
torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1
# run GPT2-XL on single-node 4-GPU
torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2
# run GPT3 on 8-node 8-GPU
torchrun --nnodes 8 --nproc_per_node 8 \
--rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \
benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini
```
> ⚠ Batch sizes in the CLI args and the reported throughput/TFLOPS are all per-GPU values.
In this benchmark, we assume the model architectures/sizes of the actor and critic are the same for simplicity. In practice, a smaller critic may be used to reduce training cost, as in the sketch below.
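A minimal sketch of mixing model sizes (illustrative only: `benchmark_gpt_dummy.py` does not expose separate actor/critic sizes, and `get_gpt_config` here refers to the helper defined inside that script):
```python
from chatgpt.nn import GPTActor, GPTCritic

from benchmark_gpt_dummy import get_gpt_config    # helper from this directory

# pair a large actor with a small critic to cut training cost (illustrative sizes)
actor = GPTActor(config=get_gpt_config('xl')).cuda()
critic = GPTCritic(config=get_gpt_config('s')).cuda()
```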
We also provide a simple shell script to run a set of benchmarks. It only supports single-node runs, but it is easy to adapt it to multiple nodes by modifying the launch command in the script.
Usage:
```shell
# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh
# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2
# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2 ddp
# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2 ddp l
```
## Benchmark OPT with LoRA on dummy prompt data
We provide various OPT models (the string in parentheses is the corresponding model name used in this script):
- OPT-125M (125m)
- OPT-350M (350m)
- OPT-700M (700m)
- OPT-1.3B (1.3b)
- OPT-2.7B (2.7b)
- OPT-3.5B (3.5b)
- OPT-5.5B (5.5b)
- OPT-6.7B (6.7b)
- OPT-10B (10b)
- OPT-13B (13b)
Only `torchrun` launching is supported for now, e.g.:
```shell
# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size
torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
# run OPT-350M with lora_rank=4 on single-node 4-GPU
torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4
```
> ⚠ Batch sizes in the CLI args and the reported throughput/TFLOPS are all per-GPU values.
In this benchmark, we assume the model architectures/sizes of the actor and critic are the same for simplicity. In practice, a smaller critic may be used to reduce training cost.

View File

@ -0,0 +1,180 @@
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import GPTActor, GPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    numel = sum(p.numel() for p in model.parameters())
    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
        numel *= dist.get_world_size()
    return numel


def preprocess_batch(samples) -> dict:
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def print_rank_0(*args, **kwargs) -> None:
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def print_model_numel(model_dict: dict) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = ''
    for name, numel in model_dict.items():
        outputs += f'{name}: '
        if numel >= B:
            outputs += f'{numel / B:.2f} B\n'
        elif numel >= M:
            outputs += f'{numel / M:.2f} M\n'
        elif numel >= K:
            outputs += f'{numel / K:.2f} K\n'
        else:
            outputs += f'{numel}\n'
    print_rank_0(outputs)


def get_gpt_config(model_name: str) -> GPT2Config:
    model_map = {
        's': GPT2Config(),
        'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
        'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
        'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
        '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
        '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
        '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
        '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
        '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
        '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
        '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
        '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
        '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
        '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
        '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
        '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
        '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
        '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
        '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
    }
    try:
        return model_map[model_name]
    except KeyError:
        raise ValueError(f'Unknown model "{model_name}"')


def main(args):
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_gemini_cpu':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    elif args.strategy == 'colossalai_zero1':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero1_cpu':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    model_config = get_gpt_config(args.model)

    with strategy.model_init_context():
        actor = GPTActor(config=model_config).cuda()
        critic = GPTCritic(config=model_config).cuda()

        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    actor_numel = get_model_numel(actor, strategy)
    critic_numel = get_model_numel(critic, strategy)
    initial_model_numel = get_model_numel(initial_model, strategy)
    reward_model_numel = get_model_numel(reward_model, strategy)
    print_model_numel({
        'Actor': actor_numel,
        'Critic': critic_numel,
        'Initial model': initial_model_numel,
        'Reward model': reward_model_numel
    })

    performance_evaluator = PerformanceEvaluator(actor_numel,
                                                 critic_numel,
                                                 initial_model_numel,
                                                 reward_model_numel,
                                                 enable_grad_checkpoint=False,
                                                 ignore_episodes=1)

    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=512,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='s')
    parser.add_argument('--strategy',
                        choices=[
                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
                        ],
                        default='ddp')
    parser.add_argument('--num_episodes', type=int, default=3)
    parser.add_argument('--max_timesteps', type=int, default=8)
    parser.add_argument('--update_timesteps', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    args = parser.parse_args()
    main(args)

View File

@ -0,0 +1,45 @@
#!/usr/bin/env bash
# Usage: $0 <?number-of-gpus> <?strategy> <?model>

set -xu

BASE=$(realpath $(dirname $0))
PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py
export OMP_NUM_THREADS=8

function tune_batch_size() {
    # We found that when the experience batch size equals the train batch size,
    # peak CUDA memory usage of the experience-making phase is less than or equal to that of the training phase.
    # Thus, the experience batch size can be larger than or equal to the train batch size.
    for bs in 1 2 4 8 16 32 64 128 256; do
        torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1
    done
}

if [ $# -eq 0 ]; then
    num_gpus=(1 2 4 8)
else
    num_gpus=($1)
fi

if [ $# -le 1 ]; then
    strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu")
else
    strategies=($2)
fi

if [ $# -le 2 ]; then
    models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b")
else
    models=($3)
fi

for num_gpu in ${num_gpus[@]}; do
    for strategy in ${strategies[@]}; do
        for model in ${models[@]}; do
            tune_batch_size $num_gpu $model $strategy || break
        done
    done
done

View File

@ -0,0 +1,175 @@
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam
from transformers import AutoTokenizer
from transformers.models.opt.configuration_opt import OPTConfig

from colossalai.nn.optimizer import HybridAdam


def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    numel = sum(p.numel() for p in model.parameters())
    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
        numel *= dist.get_world_size()
    return numel


def preprocess_batch(samples) -> dict:
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def print_rank_0(*args, **kwargs) -> None:
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def print_model_numel(model_dict: dict) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = ''
    for name, numel in model_dict.items():
        outputs += f'{name}: '
        if numel >= B:
            outputs += f'{numel / B:.2f} B\n'
        elif numel >= M:
            outputs += f'{numel / M:.2f} M\n'
        elif numel >= K:
            outputs += f'{numel / K:.2f} K\n'
        else:
            outputs += f'{numel}\n'
    print_rank_0(outputs)


def get_opt_config(model_name: str) -> OPTConfig:
    model_map = {
        '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
        '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
        '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
        '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
        '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
        '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
        '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
        '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
        '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
        '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
    }
    try:
        return model_map[model_name]
    except KeyError:
        raise ValueError(f'Unknown model "{model_name}"')


def main(args):
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_gemini_cpu':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    elif args.strategy == 'colossalai_zero1':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero1_cpu':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)

    model_config = get_opt_config(args.model)

    with strategy.model_init_context():
        actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
        critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()

        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    actor_numel = get_model_numel(actor, strategy)
    critic_numel = get_model_numel(critic, strategy)
    initial_model_numel = get_model_numel(initial_model, strategy)
    reward_model_numel = get_model_numel(reward_model, strategy)
    print_model_numel({
        'Actor': actor_numel,
        'Critic': critic_numel,
        'Initial model': initial_model_numel,
        'Reward model': reward_model_numel
    })

    performance_evaluator = PerformanceEvaluator(actor_numel,
                                                 critic_numel,
                                                 initial_model_numel,
                                                 reward_model_numel,
                                                 enable_grad_checkpoint=False,
                                                 ignore_episodes=1)

    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    tokenizer.pad_token = tokenizer.eos_token

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=512,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='125m')
    parser.add_argument('--strategy',
                        choices=[
                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
                        ],
                        default='ddp')
    parser.add_argument('--num_episodes', type=int, default=3)
    parser.add_argument('--max_timesteps', type=int, default=8)
    parser.add_argument('--update_timesteps', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=4)
    parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
    args = parser.parse_args()
    main(args)

View File

@ -0,0 +1,3 @@
from .reward_dataset import RewardDataset
__all__ = ['RewardDataset']

View File

@ -0,0 +1,52 @@
from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm


class RewardDataset(Dataset):
    """
    Dataset for reward model

    Args:
        dataset: dataset for reward model
        tokenizer: tokenizer for reward model
        max_length: max length of input
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        for data in tqdm(dataset):
            prompt = data['prompt']

            chosen = prompt + data['chosen'] + "<|endoftext|>"
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + "<|endoftext|>"
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]

View File

@ -0,0 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker
__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']

View File

@ -0,0 +1,77 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from chatgpt.nn.actor import Actor


@dataclass
class Experience:
    """Experience is a batch of data.
    These data should have the same sequence length and number of actions.
    Left padding for sequences is applied.

    Shapes of each tensor:
        sequences: (B, S)
        action_log_probs: (B, A)
        values: (B)
        reward: (B)
        advantages: (B)
        attention_mask: (B, S)
        action_mask: (B, A)

    "A" is the number of actions.
    """
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

    @torch.no_grad()
    def to_device(self, device: torch.device) -> None:
        self.sequences = self.sequences.to(device)
        self.action_log_probs = self.action_log_probs.to(device)
        self.values = self.values.to(device)
        self.reward = self.reward.to(device)
        self.advantages = self.advantages.to(device)
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.to(device)
        if self.action_mask is not None:
            self.action_mask = self.action_mask.to(device)

    def pin_memory(self):
        self.sequences = self.sequences.pin_memory()
        self.action_log_probs = self.action_log_probs.pin_memory()
        self.values = self.values.pin_memory()
        self.reward = self.reward.pin_memory()
        self.advantages = self.advantages.pin_memory()
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.pin_memory()
        if self.action_mask is not None:
            self.action_mask = self.action_mask.pin_memory()
        return self


class ExperienceMaker(ABC):

    def __init__(self,
                 actor: Actor,
                 critic: nn.Module,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 kl_coef: float = 0.1) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.initial_model = initial_model
        self.kl_coef = kl_coef

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
        pass

View File

@ -0,0 +1,36 @@
import torch
from chatgpt.nn.utils import compute_reward, normalize

from .base import Experience, ExperienceMaker


class NaiveExperienceMaker(ExperienceMaker):
    """
    Naive experience maker.
    """

    @torch.no_grad()
    def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
        self.actor.eval()
        self.critic.eval()
        self.initial_model.eval()
        self.reward_model.eval()

        sequences, attention_mask, action_mask = self.actor.generate(input_ids,
                                                                     return_action_mask=True,
                                                                     **generate_kwargs)
        num_actions = action_mask.size(1)

        action_log_probs = self.actor(sequences, num_actions, attention_mask)
        base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
        value = self.critic(sequences, action_mask, attention_mask)
        r = self.reward_model(sequences, attention_mask)
        reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)

        advantage = reward - value
        # TODO(ver217): maybe normalize adv
        if advantage.ndim == 1:
            advantage = advantage.unsqueeze(-1)

        return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)
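`compute_reward` is imported from `chatgpt/nn/utils.py`, which is not part of this diff. A minimal sketch of the standard RLHF KL-penalized reward it presumably implements (an assumption, not the library's exact code):
```python
import torch


def compute_reward_sketch(r: torch.Tensor,
                          kl_coef: float,
                          log_probs: torch.Tensor,
                          base_log_probs: torch.Tensor,
                          action_mask: torch.Tensor = None) -> torch.Tensor:
    # approximate per-sequence KL between the trained actor and the frozen initial model
    kl = log_probs - base_log_probs    # (B, A)
    if action_mask is not None:
        kl = (kl * action_mask).sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)
    else:
        kl = kl.mean(dim=-1)
    # penalize the raw reward by divergence from the initial model
    return r - kl_coef * kl    # (B,)
```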

View File

@ -0,0 +1,18 @@
from .actor import Actor
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
from .critic import Critic
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
from .reward_model import RewardModel
__all__ = [
    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss', 'GPTActor',
    'GPTCritic', 'GPTRM', 'BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'OPTActor', 'OPTCritic', 'OPTRM'
]

View File

@ -0,0 +1,62 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .generation import generate
from .lora import LoRAModule
from .utils import log_probs_from_logits


class Actor(LoRAModule):
    """
    Actor model base class.

    Args:
        model (nn.Module): Actor Model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        return_action_mask: bool = True,
        **kwargs
    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
        sequences = generate(self.model, input_ids, **kwargs)
        attention_mask = None
        pad_token_id = kwargs.get('pad_token_id', None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask
        input_len = input_ids.size(1)
        eos_token_id = kwargs.get('eos_token_id', None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

    def forward(self,
                sequences: torch.LongTensor,
                num_actions: int,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns action log probs
        """
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]
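A small usage sketch of the `Actor` API above (assuming `GPTActor`'s constructor mirrors `BLOOMActor` below, i.e. a default config is built when neither `pretrained` nor `config` is given):
```python
import torch
from chatgpt.nn import GPTActor
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

actor = GPTActor().cuda()
input_ids = tokenizer("The quick brown fox", return_tensors='pt')['input_ids'].cuda()

# kwargs are forwarded to the generate() helper in generation.py
sequences, attention_mask, action_mask = actor.generate(input_ids,
                                                        return_action_mask=True,
                                                        max_length=64,
                                                        do_sample=True,
                                                        pad_token_id=tokenizer.pad_token_id,
                                                        eos_token_id=tokenizer.eos_token_id)
# log probs of the generated actions only, shape (B, num_actions)
action_log_probs = actor(sequences, action_mask.size(1), attention_mask)
```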

View File

@ -0,0 +1,35 @@
from typing import Optional

import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from .actor import Actor


class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)

View File

@ -0,0 +1,37 @@
from typing import Optional

import torch
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from .critic import Critic


class BLOOMCritic(Critic):
    """
    BLOOM Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomModel.from_pretrained(pretrained)
        elif config is not None:
            model = BloomModel(config)
        else:
            model = BloomModel(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@ -0,0 +1,37 @@
from typing import Optional

import torch
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from .reward_model import RewardModel


class BLOOMRM(RewardModel):
    """
    BLOOM Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomModel.from_pretrained(pretrained)
        elif config is not None:
            model = BloomModel(config)
        else:
            model = BloomModel(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@ -0,0 +1,47 @@
from typing import Optional

import torch
import torch.nn as nn

from .lora import LoRAModule
from .utils import masked_mean


class Critic(LoRAModule):
    """
    Critic model base class.

    Args:
        model (nn.Module): Critic model.
        value_head (nn.Module): Value head to get value.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 model: nn.Module,
                 value_head: nn.Module,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.value_head = value_head
        self.convert_to_lora()

    def forward(self,
                sequences: torch.LongTensor,
                action_mask: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs['last_hidden_state']
        values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
        if action_mask is not None:
            num_actions = action_mask.size(1)
            values = values[:, -num_actions:]
            value = masked_mean(values, action_mask, dim=1)
            return value
        value = values.mean(dim=1)
        return value
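`masked_mean` is imported from `chatgpt/nn/utils.py`, which is not part of this diff; a plausible implementation, stated as an assumption rather than the library's actual code, would be:
```python
import torch


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # average over `dim`, counting only positions where mask is nonzero
    masked = tensor * mask
    return masked.sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)
```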

View File

@ -0,0 +1,137 @@
from typing import Any, Callable, Optional
import torch
import torch.nn as nn
try:
from transformers.generation_logits_process import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
processor_list.append(TopKLogitsWarper(top_k))
if top_p is not None and top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
return processor_list
def sample(model: nn.Module,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
'input_ids': input_ids
}
outputs = model(**model_inputs)
next_token_logits = outputs['logits'][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if update_model_kwargs_fn is not None:
model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
# stop when each sentence is finished if early_stopping=True
if early_stopping and unfinished_sequences.max() == 0:
break
return input_ids
def generate(model: nn.Module,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
model (nn.Module): model
input_ids (torch.Tensor): input sequence
max_length (int): max length of the returned sequence
num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to use sampling. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
if is_greedy_gen_mode:
# run greedy search
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
return sample(model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs)
elif is_beam_gen_mode:
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")

View File

@ -0,0 +1,92 @@
from typing import Optional
import torch
def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
if "past_key_values" in outputs:
model_kwargs["past"] = outputs["past_key_values"]
else:
model_kwargs["past"] = None
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
return model_kwargs
def opt_prepare_inputs_fn(input_ids: torch.Tensor,
past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
**kwargs) -> dict:
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
**kwargs) -> dict:
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}

View File

@ -0,0 +1,31 @@
from typing import Optional
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from .actor import Actor
class GPTActor(Actor):
"""
GPT Actor model.
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False) -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
model = GPT2LMHeadModel(config)
else:
model = GPT2LMHeadModel(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model)

View File

@ -0,0 +1,33 @@
from typing import Optional
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from .critic import Critic
class GPTCritic(Critic):
"""
GPT Critic model.
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head)

View File

@ -0,0 +1,33 @@
from typing import Optional
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from .reward_model import RewardModel
class GPTRM(RewardModel):
"""
GPT Reward model.
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head)

View File

@ -0,0 +1,127 @@
import math
from typing import Optional
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
"""
def __init__(
self,
weight: nn.Parameter,
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.weight = weight
self.bias = bias
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, lora_linear_wrapper(child, lora_rank))
else:
convert_to_lora_recursively(child, lora_rank)
class LoRAModule(nn.Module):
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
This class will convert all torch.nn.Linear layers to LoraLinear layers.
Args:
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
lora_train_bias (str, optional): Which biases to train with LoRA.
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
Defaults to 'none'.
"""
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
def convert_to_lora(self) -> None:
if self.lora_rank <= 0:
return
convert_to_lora_recursively(self, self.lora_rank)
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)

View File

@ -0,0 +1,105 @@
from typing import Optional
import torch
import torch.nn as nn
from .utils import masked_mean
class GPTLMLoss(nn.Module):
"""
GPT Language Model Loss
"""
def __init__(self):
super().__init__()
self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
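# shift so that tokens < n predict token n (the standard causal LM objective)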
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
class PolicyLoss(nn.Module):
"""
Policy Loss for PPO
"""
def __init__(self, clip_eps: float = 0.2) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
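# PPO clipped surrogate objective: -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)],
# where r = exp(log_probs - old_log_probs) is the importance ratio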
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
if action_mask is not None:
loss = masked_mean(loss, action_mask)
loss = loss.mean()
return loss
class ValueLoss(nn.Module):
"""
Value Loss for PPO
"""
def __init__(self, clip_eps: float = 0.4) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(self,
values: torch.Tensor,
old_values: torch.Tensor,
reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
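# clipped value loss: keep the updated value within clip_eps of the old value
# and penalize the larger of the clipped / unclipped squared errors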
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2
surr2 = (values - reward)**2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return loss
class PPOPtxActorLoss(nn.Module):
"""
TODO: PPO-ptx Actor Loss
"""
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
super().__init__()
self.pretrain_coef = pretrain_coef
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
self.pretrain_loss_fn = pretrain_loss_fn
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
lm_logits: torch.Tensor,
lm_input_ids: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
return policy_loss + self.pretrain_coef * lm_loss
class PairWiseLoss(nn.Module):
"""
Pairwise Loss for Reward Model
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
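# Bradley-Terry style ranking loss: -log sigmoid(r_chosen - r_rejected)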
log_probs = F.logsigmoid(chosen_reward - reject_reward)    # numerically stable log(sigmoid(x))
loss = -log_probs.mean()
return loss

View File

@ -0,0 +1,35 @@
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from .actor import Actor
class OPTActor(Actor):
"""
OPT Actor model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)

View File

@ -0,0 +1,37 @@
from typing import Optional
import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel
from .critic import Critic
class OPTCritic(Critic):
"""
OPT Critic model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@ -0,0 +1,33 @@
from typing import Optional
import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel
from .reward_model import RewardModel
class OPTRM(RewardModel):
"""
OPT Reward model.
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@ -0,0 +1,41 @@
from typing import Optional
import torch
import torch.nn as nn
from .lora import LoRAModule
class RewardModel(LoRAModule):
"""
Reward model base class.
Args:
model (nn.Module): Reward model.
value_head (nn.Module): Value head to get reward score.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
model: nn.Module,
value_head: Optional[nn.Module] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
if value_head is not None:
if value_head.out_features != 1:
raise ValueError("The value head of reward model's output dim should be 1!")
self.value_head = value_head
else:
self.value_head = nn.Linear(model.config.n_embd, 1)
self.convert_to_lora()
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state']
values = self.value_head(last_hidden_states)[:, :-1]
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
return value

View File

@ -0,0 +1,92 @@
from typing import Optional, Union
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
Args:
log_probs: Log probabilities of the new distribution.
log_probs_base: Log probabilities of the base distribution.
action_mask: Mask for actions.
"""
log_ratio = log_probs - log_probs_base
approx_kl = (log_ratio.exp() - 1) - log_ratio
if action_mask is not None:
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
return approx_kl
approx_kl = approx_kl.mean(dim=1)
return approx_kl
def compute_reward(r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
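# shape the reward with a KL penalty so the policy stays close to the base model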
if kl_coef <= 0.0:
return r
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
reward = r - kl_coef * kl
return reward
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
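# mean over dim counting only unmasked entries; the 1e-8 guards rows whose mask is all zeros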
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
tensor = tensor * mask
mean = masked_mean(tensor, mask, dim=dim)
mean_centered = tensor - mean
var = masked_mean(mean_centered**2, mask, dim=dim)
return mean_centered * var.clamp(min=eps).rsqrt()
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
mean = tensor.mean(dim)
mean_centered = tensor - mean
var = (mean_centered**2).mean(dim)
norm = mean_centered * var.clamp(min=eps).rsqrt()
return norm
def convert_to_lora(model: nn.Module,
input_size: int,
output_size: int,
lora_rank: int = 16,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
merge_weights: bool = True):
if lora_rank > min(input_size, output_size):
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
module._modules[name] = lora.Linear(input_size,
output_size,
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
fan_in_fan_out=fan_in_fan_out,
merge_weights=merge_weights)

View File

@ -0,0 +1,4 @@
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']

View File

@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Any
from chatgpt.experience_maker.base import Experience
class ReplayBuffer(ABC):
"""Replay buffer base class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
super().__init__()
self.sample_batch_size = sample_batch_size
# limit <= 0 means unlimited
self.limit = limit
@abstractmethod
def append(self, experience: Experience) -> None:
pass
@abstractmethod
def clear(self) -> None:
pass
@abstractmethod
def sample(self) -> Experience:
pass
@abstractmethod
def __len__(self) -> int:
pass
@abstractmethod
def __getitem__(self, idx: int) -> Any:
pass
@abstractmethod
def collate_fn(self, batch: Any) -> Experience:
pass

View File

@ -0,0 +1,57 @@
import random
from typing import List
import torch
from chatgpt.experience_maker.base import Experience
from .base import ReplayBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
class NaiveReplayBuffer(ReplayBuffer):
"""Naive replay buffer class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
"""
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
super().__init__(sample_batch_size, limit)
self.cpu_offload = cpu_offload
self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
@torch.no_grad()
def append(self, experience: Experience) -> None:
if self.cpu_offload:
experience.to_device(torch.device('cpu'))
items = split_experience_batch(experience)
self.items.extend(items)
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
self.items = self.items[samples_to_remove:]
def clear(self) -> None:
self.items.clear()
@torch.no_grad()
def sample(self) -> Experience:
items = random.sample(self.items, self.sample_batch_size)
experience = make_experience_batch(items)
if self.cpu_offload:
experience.to_device(self.target_device)
return experience
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, idx: int) -> BufferItem:
return self.items[idx]
def collate_fn(self, batch) -> Experience:
experience = make_experience_batch(batch)
return experience

View File

@ -0,0 +1,73 @@
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn.functional as F
from chatgpt.experience_maker.base import Experience
@dataclass
class BufferItem:
"""BufferItem is an item of experience data.
Shapes of each tensor:
sequences: (S)
action_log_probs: (A)
values: (1)
reward: (1)
advantages: (1)
attention_mask: (S)
action_mask: (A)
"A" is the number of actions.
"""
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
reward: torch.Tensor
advantages: torch.Tensor
attention_mask: Optional[torch.LongTensor]
action_mask: Optional[torch.BoolTensor]
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
vals = torch.unbind(value)
else:
# None
vals = [value for _ in range(batch_size)]
assert batch_size == len(vals)
for i, v in enumerate(vals):
batch_kwargs[i][key] = v
items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
return items
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
assert side in ('left', 'right')
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - seq.size(0)
padding = (pad_len, 0) if side == 'left' else (0, pad_len)
padded_sequences.append(F.pad(seq, padding))
return torch.stack(padded_sequences, dim=0)
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(('action_log_probs', 'action_mask'))
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
batch_data = zero_pad_sequences(vals)
else:
batch_data = torch.stack(vals, dim=0)
kwargs[key] = batch_data
return Experience(**kwargs)

View File

@ -0,0 +1,5 @@
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']

View File

@ -0,0 +1,162 @@
import random
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from chatgpt.experience_maker import Experience, ExperienceMaker
from chatgpt.replay_buffer import ReplayBuffer
from torch import Tensor
from torch.utils.data import DistributedSampler
from tqdm import tqdm
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
class Trainer(ABC):
"""
Base class for rlhf trainers.
Args:
strategy (Strategy): the strategy to use for training
experience_maker (ExperienceMaker): the experience maker used to produce experience to fill the replay buffer
replay_buffer (ReplayBuffer): the replay buffer to use for training
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
experience_maker: ExperienceMaker,
replay_buffer: ReplayBuffer,
experience_batch_size: int = 8,
max_epochs: int = 1,
tokenizer: Optional[Callable[[Any], dict]] = None,
sample_replay_buffer: bool = False,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
super().__init__()
self.strategy = strategy
self.experience_maker = experience_maker
self.replay_buffer = replay_buffer
self.experience_batch_size = experience_batch_size
self.max_epochs = max_epochs
self.tokenizer = tokenizer
self.generate_kwargs = generate_kwargs
self.sample_replay_buffer = sample_replay_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
@abstractmethod
def training_step(self, experience: Experience) -> Dict[str, Any]:
pass
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
def _sample_prompts(self, prompts) -> list:
indices = list(range(len(prompts)))
sampled_indices = random.sample(indices, self.experience_batch_size)
return [prompts[i] for i in sampled_indices]
def _learn(self):
# the replay buffer may be empty at first; rebuild the dataloader each time we learn
if not self.sample_replay_buffer:
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
device = torch.cuda.current_device()
if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for _ in pbar:
experience = self.replay_buffer.sample()
metrics = self.training_step(experience)
pbar.set_postfix(metrics)
else:
for epoch in range(self.max_epochs):
self._on_learn_epoch_start(epoch)
if isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch)
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(device)
metrics = self.training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._on_learn_epoch_end(epoch)
def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
time = 0
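# each timestep makes one batch of experience; a learning phase runs every update_timesteps steps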
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
for timestep in tqdm(range(max_timesteps),
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
time += 1
rand_prompts = self._sample_prompts(prompts)
if self.tokenizer is not None:
inputs = self.tokenizer(rand_prompts)
else:
inputs = rand_prompts
self._on_make_experience_start()
experience = self._make_experience(inputs)
self._on_make_experience_end(experience)
self.replay_buffer.append(experience)
if time % update_timesteps == 0:
self._learn()
self.replay_buffer.clear()
self._on_episode_end(episode)
self._on_fit_end()
# TODO(ver217): maybe simplify this code using a context manager
def _on_fit_start(self) -> None:
for callback in self.callbacks:
callback.on_fit_start()
def _on_fit_end(self) -> None:
for callback in self.callbacks:
callback.on_fit_end()
def _on_episode_start(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_start(episode)
def _on_episode_end(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_end(episode)
def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
callback.on_make_experience_start()
def _on_make_experience_end(self, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_make_experience_end(experience)
def _on_learn_epoch_start(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_learn_epoch_start(epoch)
def _on_learn_epoch_end(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_learn_epoch_end(epoch)
def _on_learn_batch_start(self) -> None:
for callback in self.callbacks:
callback.on_learn_batch_start()
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)

View File

@ -0,0 +1,4 @@
from .base import Callback
from .performance_evaluator import PerformanceEvaluator
__all__ = ['Callback', 'PerformanceEvaluator']

View File

@ -0,0 +1,39 @@
from abc import ABC
from chatgpt.experience_maker import Experience
class Callback(ABC):
"""
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
pass
def on_fit_end(self) -> None:
pass
def on_episode_start(self, episode: int) -> None:
pass
def on_episode_end(self, episode: int) -> None:
pass
def on_make_experience_start(self) -> None:
pass
def on_make_experience_end(self, experience: Experience) -> None:
pass
def on_learn_epoch_start(self, epoch: int) -> None:
pass
def on_learn_epoch_end(self, epoch: int) -> None:
pass
def on_learn_batch_start(self) -> None:
pass
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
pass

View File

@ -0,0 +1,133 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
from chatgpt.experience_maker import Experience
from .base import Callback
def get_world_size() -> int:
if dist.is_initialized():
return dist.get_world_size()
return 1
def print_rank_0(*args, **kwargs) -> None:
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class PerformanceEvaluator(Callback):
"""
Callback for evaluating the performance of the model.
Args:
actor_num_params: The number of parameters of the actor model.
critic_num_params: The number of parameters of the critic model.
initial_model_num_params: The number of parameters of the initial model.
reward_model_num_params: The number of parameters of the reward model.
enable_grad_checkpoint: Whether to enable gradient checkpointing.
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
self.critic_num_params = critic_num_params
self.initial_model_num_params = initial_model_num_params
self.reward_model_num_params = reward_model_num_params
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_episodes = ignore_episodes
self.disable: bool = False
self.make_experience_duration: float = 0.
self.make_experience_start_time: Optional[float] = None
self.make_experience_num_samples: int = 0
self.make_experience_flop: int = 0
self.learn_duration: float = 0.
self.learn_start_time: Optional[float] = None
self.learn_num_samples: int = 0
self.learn_flop: int = 0
def on_episode_start(self, episode: int) -> None:
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
def on_make_experience_start(self) -> None:
if self.disable:
return
self.make_experience_start_time = time()
def on_make_experience_end(self, experience: Experience) -> None:
if self.disable:
return
self.make_experience_duration += time() - self.make_experience_start_time
batch_size, seq_len = experience.sequences.shape
self.make_experience_num_samples += batch_size
# actor generate
num_actions = experience.action_mask.size(1)
input_len = seq_len - num_actions
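# arithmetic-series sum of context lengths as generation grows from input_len to seq_len - 1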
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
# actor forward
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
# critic forward
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
# initial model forward
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
# reward model forward
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
def on_learn_batch_start(self) -> None:
if self.disable:
return
self.learn_start_time = time()
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
if self.disable:
return
self.learn_duration += time() - self.learn_start_time
batch_size, seq_len = experience.sequences.shape
self.learn_num_samples += batch_size
# actor forward-backward, 3 means forward(1) + backward(2)
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
# critic forward-backward
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
def on_fit_end(self) -> None:
avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
print_rank_0(
f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
)
print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')

View File

@ -0,0 +1,114 @@
from typing import Any, Callable, Dict, List, Optional
import torch.nn as nn
from chatgpt.experience_maker import Experience, NaiveExperienceMaker
from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
from chatgpt.nn.generation_utils import update_model_kwargs_fn
from chatgpt.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer
from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
class PPOTrainer(Trainer):
"""
Trainer for PPO algorithm.
Args:
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the maximum size of the replay buffer
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
value_clip: float = 0.4,
experience_batch_size: int = 8,
max_epochs: int = 1,
tokenizer: Optional[Callable[[Any], dict]] = None,
sample_replay_buffer: bool = False,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
self._set_default_generate_kwargs(generate_kwargs, actor)
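# re-wrap the strategy-prepared inner models in Actor so the generation helpers keep working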
actor = Actor(strategy.setup_model(actor.model))
critic = strategy.setup_model(critic)
reward_model = strategy.setup_model(reward_model)
initial_model = Actor(strategy.setup_model(initial_model.model))
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
self.actor = actor
self.critic = critic
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model)
self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic)
def training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
self.critic.train()
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

View File

@ -0,0 +1,77 @@
from abc import ABC
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import PairWiseLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
class RewardModelTrainer(ABC):
"""
Trainer to use while training reward model.
Args:
model (torch.nn.Module): the model to train
train_dataset (RewardDataset): the dataset to use for training
eval_dataset (RewardDataset): the dataset to use for evaluation
batch_size (int, defaults to 1): the batch size while training
num_epochs (int, defaults to 2): the number of epochs to train
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
"""
def __init__(self,
model,
train_dataset: RewardDataset,
eval_dataset: RewardDataset,
batch_size: int = 1,
num_epochs: int = 2,
optim_kwargs: dict = {'lr': 1e-4}) -> None:
super().__init__()
self.model = model
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
self.loss_fn = PairWiseLoss()
self.optimizer = Adam(self.model.parameters(), **optim_kwargs)
self.epochs = num_epochs
def fit(self, use_lora):
epoch_bar = tqdm(range(self.epochs), desc='Train epoch')
for epoch in range(self.epochs):
step_bar = tqdm(range(len(self.train_dataloader)), desc='Train step of epoch %d' % epoch)
# train
if use_lora > 0:
print("Using Lora")
lora.mark_only_lora_as_trainable(self.model)
else:
self.model.train()
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).cuda()
c_mask = c_mask.squeeze(1).cuda()
reject_ids = reject_ids.squeeze(1).cuda()
r_mask = r_mask.squeeze(1).cuda()
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
step_bar.update()
step_bar.set_postfix({'loss': loss.item()})
# eval
self.model.eval()
dist = 0
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).cuda()
c_mask = c_mask.squeeze(1).cuda()
reject_ids = reject_ids.squeeze(1).cuda()
r_mask = r_mask.squeeze(1).cuda()
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
dist += (chosen_reward - reject_reward).mean().item()    # accumulate across the whole eval set
dist_mean = dist / len(self.eval_dataloader)
epoch_bar.update()
step_bar.set_postfix({'loss': loss.item(), 'dist_mean': dist_mean})
step_bar.close()

View File

@ -0,0 +1,6 @@
from .base import Strategy
from .colossalai import ColossalAIStrategy
from .ddp import DDPStrategy
from .naive import NaiveStrategy
__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']

View File

@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext
import torch
import torch.nn as nn
import torch.optim as optim
from chatgpt.replay_buffer import ReplayBuffer
from torch.utils.data import DataLoader
class Strategy(ABC):
"""
Base class for training strategies.
"""
def __init__(self) -> None:
super().__init__()
self.setup_distributed()
@abstractmethod
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
pass
@abstractmethod
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
pass
@abstractmethod
def setup_distributed(self) -> None:
pass
@abstractmethod
def setup_model(self, model: nn.Module) -> nn.Module:
pass
@abstractmethod
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
pass
@abstractmethod
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
pass
def model_init_context(self):
return nullcontext()

View File

@ -0,0 +1,125 @@
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import colossalai
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from .ddp import DDPStrategy
class ColossalAIStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI.
Args:
stage(int): The stage to use in ZeRO. Choose from (1, 2, 3).
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
placement_policy(str): The placement policy for Gemini. Choose from ('cpu', 'cuda').
If 'cpu', parameters, gradients and optimizer states are offloaded to CPU;
if 'cuda', nothing is offloaded, which uses the most CUDA memory but is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for Gemini. Only for ZeRO-3.
min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.
"""
def __init__(
self,
stage: int = 3,
seed: int = 42,
shard_init: bool = True, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_mb: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_mb: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0) -> None:
super().__init__(seed)
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
self.stage = stage
self.shard_init = shard_init
self.gemini_config = dict(device=get_current_device(),
placement_policy=placement_policy,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_mb=search_range_mb,
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb)
if stage == 3:
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
else:
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'))
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
if self.stage == 3:
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
return super().model_init_context()
def setup_model(self, model: nn.Module) -> nn.Module:
return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.backward(loss)
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()

View File

@ -0,0 +1,59 @@
import os
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from .naive import NaiveStrategy
class DDPStrategy(NaiveStrategy):
"""
Strategy for distributed training using torch.distributed.
"""
def __init__(self, seed: int = 42) -> None:
self.seed = seed
super().__init__()
def setup_distributed(self) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
self.set_seed(self.seed)
torch.cuda.set_device(local_rank)
def set_seed(self, seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_model(self, model: nn.Module) -> nn.Module:
device = torch.cuda.current_device()
return DDP(model, device_ids=[device])
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
sampler = DistributedSampler(replay_buffer,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=True,
seed=self.seed,
drop_last=True)
return DataLoader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
sampler=sampler,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)

View File

@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.optim as optim
from chatgpt.replay_buffer import ReplayBuffer
from torch.utils.data import DataLoader
from .base import Strategy
class NaiveStrategy(Strategy):
"""
Strategy for single GPU. No parallelism is used.
"""
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
loss.backward()
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()
def setup_distributed(self) -> None:
pass
def setup_model(self, model: nn.Module) -> nn.Module:
return model
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
return optimizer
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return DataLoader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)

View File

@ -0,0 +1,5 @@
import torch.distributed as dist
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0

View File

@ -0,0 +1,105 @@
# Examples
## Install requirements
```shell
pip install -r requirements.txt
```
## Train with dummy prompt data
This script supports 3 strategies:
- naive
- ddp
- colossalai
It uses randomly generated prompt data.
The naive strategy only supports single-GPU training:
```shell
python train_dummy.py --strategy naive
# display cli help
python train_dummy.py -h
```
The DDP and ColossalAI strategies support multi-GPU training:
```shell
# run DDP on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
# run ColossalAI on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai
```
## Train with real prompt data
We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as the example dataset. It is a small dataset with hundreds of prompts.
You should download `prompts.csv` first.
This script also supports 3 strategies.
```shell
# display cli help
python train_prompts.py -h
# run naive on 1 GPU
python train_prompts.py prompts.csv --strategy naive
# run DDP on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ddp
# run ColossalAI on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai
```
## Train the reward model
We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It contains chosen & rejected responses to the same prompts.
The dataset is downloaded from Hugging Face automatically.
Use the following commands to train your reward model.
```shell
# Naive reward model training
python train_reward_model.py --pretrain <your model path>
# train the reward model with LoRA
python train_reward_model.py --pretrain <your model path> --lora_rank 16
```
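Conceptually, a reward model on such data is trained with a pairwise ranking objective so that the chosen response scores above the rejected one. The sketch below shows the common form of this loss; it is an illustration, not necessarily the exact loss used by `RewardModelTrainer`:
```python
import torch
import torch.nn.functional as F

def pairwise_ranking_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # encourage r(chosen) > r(rejected) by minimizing -log(sigmoid(margin))
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()
```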
## Supported Models
### GPT
- [ ] GPT2-S (s)
- [ ] GPT2-M (m)
- [ ] GPT2-L (l)
- [ ] GPT2-XL (xl)
- [ ] GPT2-4B (4b)
- [ ] GPT2-6B (6b)
- [ ] GPT2-8B (8b)
- [ ] GPT2-10B (10b)
- [ ] GPT2-12B (12b)
- [ ] GPT2-15B (15b)
- [ ] GPT2-18B (18b)
- [ ] GPT2-20B (20b)
- [ ] GPT2-24B (24b)
- [ ] GPT2-28B (28b)
- [ ] GPT2-32B (32b)
- [ ] GPT2-36B (36b)
- [ ] GPT2-40B (40b)
- [ ] GPT3 (175b)
### BLOOM
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [ ] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [ ] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
- [ ] BLOOM-175b
### OPT
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
- [ ] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
- [ ] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
- [ ] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)

View File

@ -0,0 +1 @@
pandas>=1.4.1

View File

@ -0,0 +1,27 @@
#!/usr/bin/env bash
set -xue
if [ -z "$PROMPT_PATH" ]; then
echo "Please set \$PROMPT_PATH to the path to prompts csv."
exit 1
fi
BASE=$(realpath $(dirname $0))
export OMP_NUM_THREADS=8
# install requirements
pip install -r ${BASE}/requirements.txt
# train dummy
python ${BASE}/train_dummy.py --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
for strategy in ddp colossalai_gemini colossalai_zero2; do
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
done
# train prompts
python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
for strategy in ddp colossalai_gemini colossalai_zero2; do
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
done

View File

@ -0,0 +1,112 @@
import argparse
from copy import deepcopy
import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
def preprocess_batch(samples):
input_ids = torch.stack(samples)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
return {'input_ids': input_ids, 'attention_mask': attention_mask}
def main(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
if args.model == 'gpt2':
actor = GPTActor().cuda()
critic = GPTCritic().cuda()
elif args.model == 'bloom':
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'opt':
actor = OPTActor().cuda()
critic = OPTCritic().cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
initial_model = deepcopy(actor).cuda()
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
# configure optimizer
if args.strategy.startswith('colossalai'):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else:
actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure trainer
trainer = PPOTrainer(
strategy,
actor,
critic,
reward_model,
initial_model,
actor_optim,
critic_optim,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
tokenizer=preprocess_batch,
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
trainer.fit(random_prompts,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--num_episodes', type=int, default=50)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
python train_dummy.py --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16

View File

@ -0,0 +1,112 @@
import argparse
from copy import deepcopy
import pandas as pd
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
if args.model == 'gpt2':
actor = GPTActor().cuda()
critic = GPTCritic().cuda()
elif args.model == 'bloom':
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'opt':
actor = OPTActor(lora_rank=args.lora_rank).cuda()
critic = OPTCritic(lora_rank=args.lora_rank).cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
initial_model = deepcopy(actor)
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
# configure optimizer
if args.strategy.startswith('colossalai'):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else:
actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
dataset = pd.read_csv(args.prompt_path)['prompt']
def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
# configure trainer
trainer = PPOTrainer(
strategy,
actor,
critic,
reward_model,
initial_model,
actor_optim,
critic_optim,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
tokenizer=tokenize_fn,
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
trainer.fit(dataset,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('prompt_path')
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
python train_prompts.py prompts.csv --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16

View File

@ -0,0 +1,53 @@
import argparse
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.trainer import RewardModelTrainer
from datasets import load_dataset
from transformers import BloomTokenizerFast
def train(args):
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
model = BLOOMRM(pretrained=args.pretrain)
model.cuda()
max_len = 1024
# prepare for data and dataset
data = load_dataset(args.dataset)
train_data = data["train"]
eval_data = data['test']
train_dataset = RewardDataset(train_data, tokenizer, max_len)
eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
# batch_size here is expected to be C(k, 2), where k is the number of responses per prompt
# limited by the format of the 'Dahoas/rm-static' dataset, batch_size should be 1
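# e.g. with k = 2 responses per prompt, C(2, 2) = 1, i.e. one (chosen, rejected) pair per prompt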
trainer = RewardModelTrainer(model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
batch_size=args.batch_size,
num_epochs=args.max_epochs)
trainer.fit(use_lora=args.lora_rank)
if args.lora_rank > 0:
torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
else:
torch.save(trainer.model, args.save_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
parser.add_argument('--max_epochs', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
args = parser.parse_args()
train(args)

View File

@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
python train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16

View File

@ -0,0 +1,6 @@
[pytest]
markers =
cpu: tests which can run on CPU
gpu: tests which require a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
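These markers allow selecting a subset of the suite; for example, a small sketch equivalent to running `pytest -m dist tests/` on the command line:
```python
import pytest

# run only the tests marked `dist` (multi-GPU / multi-machine)
pytest.main(['-m', 'dist', 'tests/'])
```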

View File

@ -0,0 +1 @@
pytest

View File

@ -0,0 +1,6 @@
transformers>=4.20.1
tqdm
datasets
loralib
colossalai>=0.2.4
torch

View File

@ -0,0 +1,41 @@
from setuptools import find_packages, setup
def fetch_requirements(path):
with open(path, 'r') as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme():
with open('README.md', encoding='utf-8') as f:
return f.read()
def fetch_version():
with open('version.txt', 'r') as f:
return f.read().strip()
setup(
name='chatgpt',
version=fetch_version(),
packages=find_packages(exclude=(
'tests',
'benchmarks',
'*.egg-info',
)),
description='An RLHF implementation (ChatGPT) powered by ColossalAI',
long_description=fetch_readme(),
long_description_content_type='text/markdown',
license='Apache Software License 2.0',
url='https://github.com/hpcaitech/ChatGPT',
install_requires=fetch_requirements('requirements.txt'),
python_requires='>=3.6',
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
'Environment :: GPU :: NVIDIA CUDA',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: System :: Distributed Computing',
],
)

View File

View File

@ -0,0 +1,117 @@
import os
from copy import deepcopy
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from chatgpt.experience_maker import NaiveExperienceMaker
from chatgpt.nn import GPTActor, GPTCritic, RewardModel
from chatgpt.replay_buffer import NaiveReplayBuffer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
def get_data(batch_size: int, seq_len: int = 10) -> dict:
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def gather_and_equal(tensor: torch.Tensor) -> bool:
world_size = dist.get_world_size()
outputs = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(outputs, tensor.contiguous())
for t in outputs[1:]:
if not torch.equal(outputs[0], t):
return False
return True
def run_test_data(strategy):
EXPERIENCE_BATCH_SIZE = 4
SAMPLE_BATCH_SIZE = 2
if strategy == 'ddp':
strategy = DDPStrategy()
elif strategy == 'colossalai':
strategy = ColossalAIStrategy(placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
actor = GPTActor().cuda()
critic = GPTCritic().cuda()
initial_model = deepcopy(actor)
reward_model = RewardModel(deepcopy(critic.model)).cuda()
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
# experience of all ranks should be the same
for _ in range(2):
data = get_data(EXPERIENCE_BATCH_SIZE)
assert gather_and_equal(data['input_ids'])
assert gather_and_equal(data['attention_mask'])
experience = experience_maker.make_experience(**data,
do_sample=True,
max_length=16,
eos_token_id=50256,
pad_token_id=50256)
assert gather_and_equal(experience.sequences)
assert gather_and_equal(experience.action_log_probs)
assert gather_and_equal(experience.values)
assert gather_and_equal(experience.reward)
assert gather_and_equal(experience.advantages)
assert gather_and_equal(experience.action_mask)
assert gather_and_equal(experience.attention_mask)
replay_buffer.append(experience)
# replay buffer's data should be the same
buffer_size = torch.tensor([len(replay_buffer)], device='cuda')
assert gather_and_equal(buffer_size)
for item in replay_buffer.items:
assert gather_and_equal(item.sequences)
assert gather_and_equal(item.action_log_probs)
assert gather_and_equal(item.values)
assert gather_and_equal(item.reward)
assert gather_and_equal(item.advantages)
assert gather_and_equal(item.action_mask)
assert gather_and_equal(item.attention_mask)
# the dataloader of each rank should have the same size but different batches
dataloader = strategy.setup_dataloader(replay_buffer)
dataloader_size = torch.tensor([len(dataloader)], device='cuda')
assert gather_and_equal(dataloader_size)
for experience in dataloader:
assert not gather_and_equal(experience.sequences)
assert not gather_and_equal(experience.action_log_probs)
assert not gather_and_equal(experience.values)
assert not gather_and_equal(experience.reward)
assert not gather_and_equal(experience.advantages)
# action mask and attention mask may be the same
def run_dist(rank, world_size, port, strategy):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(port)
run_test_data(strategy)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
@rerun_if_address_is_in_use()
def test_data(world_size, strategy):
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_data(2, 'colossalai')

View File

@ -0,0 +1 @@
0.1.0

View File

@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]
# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args.
# This list could be extended if any other method has the same
# argument style as view and reshape.
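# e.g. x.view(2, 3) passes the shape as separate ints, while x.reshape((2, 3)) passes it as one tuple.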
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]

View File

@ -19,6 +19,8 @@ from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import SHAPE_ARGUMENT_OPS
shape_consistency_manager = ShapeConsistencyManager()
@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size],
return size
def _solution_annotatation(gm: torch.fx.GraphModule,
solution: List[int],
strategies_constructor: StrategiesConstructor = None):
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
"""
mod_graph = gm.graph
# TODO: In future PR, strategies_constructor should be a required argument,
# instead of optional argument. This is because we don't need to consider nodes with
# no strategy in runtime preparation pass.
if strategies_constructor is not None:
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
no_strategy_nodes = strategies_constructor.no_strategy_nodes
else:
nodes = tuple(mod_graph.nodes)
no_strategy_nodes = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
no_strategy_nodes = strategies_constructor.no_strategy_nodes
# the dict to get origin sharding spec of node
origin_node_sharding_spec_dict = {}
@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
In the auto parallel system, tensors may get shard on different devices, so the size of tensors
need to be converted to the size of original tensor and managed by the users, such as torch.view,
@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
nodes = tuple(mod_graph.nodes)
node_pairs = {}
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
'''
A helper function to extract the target dimension from the size node.
There are two usages of torch.Tensor.size:
1. tensor.size()
2. tensor.size(dim)
If a target_dim is assigned, then the output will be an int instead of a torch.Size.
Otherwise, the output will be a torch.Size and this function will return None.
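e.g. torch.empty(2, 3).size() -> torch.Size([2, 3]), while torch.empty(2, 3).size(0) -> 2.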
'''
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
if target_dim < 0:
target_dim += node.args[0]._meta_data.dim()
return target_dim
def _post_processing(node, size_processing_node):
'''
This function is used to process the dependency between the size node and its users after
inserting the size_processing_node.
'''
# store the original node and processing node pair in the node_pairs dictionary
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
if 'activation_checkpoint' in node.meta:
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
user_list = list(node.users.keys())
for user in user_list:
if user == size_processing_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional or keyword argument of the user node
if node in new_args:
# substitute the origin node with size_processing_node
new_args[new_args.index(node)] = size_processing_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with size_processing_node
new_kwargs[str(node)] = size_processing_node
user.kwargs = new_kwargs
def _update_slice_object_args(slice_object):
'''
This function is used to update the slice object argument list.
If the slice object contains a Node argument, the size node will be replaced with its corresponding size_processing_node.
'''
if isinstance(slice_object, slice):
start = slice_object.start
stop = slice_object.stop
step = slice_object.step
if start in node_pairs:
start = node_pairs[start]
if stop in node_pairs:
stop = node_pairs[stop]
if step in node_pairs:
step = node_pairs[step]
return slice(start, stop, step)
elif isinstance(slice_object, int):
if slice_object in node_pairs:
return node_pairs[slice_object]
else:
return slice_object
else:
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes:
if node.op == 'call_method' and node.target == 'size':
@ -154,49 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
sharding_spec = node.args[0].sharding_spec
dim_partition_dict = sharding_spec.dim_partition_dict
# there are two usages of torch.Tensor.size:
# tensor.size()
# tensor.size(dim)
# if a target_dim is assigned, then the output will be
# in type of int, instead of torch.Size
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
if target_dim < 0:
target_dim += node.args[0]._meta_data.dim()
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
device_mesh_info[dim] = dim_size
target_dim = _extract_target_dim(node)
# insert size_processing node
with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function',
size_processing,
args=(node, dim_partition_dict, device_mesh_info,
target_dim, node.name))
# store the original node and processing node pair in the node_pairs dictionary
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
if 'activation_checkpoint' in node.meta:
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
user_list = list(node.users.keys())
for user in user_list:
if user == size_processing_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional or keyword argument of the user node
if node in new_args:
# substitute the origin node with size_processing_node
new_args[new_args.index(node)] = size_processing_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with size_processing_node
new_kwargs[str(node)] = size_processing_node
user.kwargs = new_kwargs
_post_processing(node, size_processing_node)
if node.op == 'call_function' and node.target == operator.getitem:
@ -217,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# In this pass, we need process the last two cases because
# node arguments may potentially appear in these cases.
if isinstance(getitem_index, slice):
new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
if getitem_index.start in node_pairs:
new_start = node_pairs[getitem_index.start]
elif getitem_index.stop in node_pairs:
new_stop = node_pairs[getitem_index.stop]
elif getitem_index.step in node_pairs:
new_step = node_pairs[getitem_index.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_slice_item = _update_slice_object_args(getitem_index)
new_args = (node.args[0], new_slice_item)
node.args = new_args
@ -237,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
if slice_item is None:
new_slice_items.append(None)
continue
new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
if slice_item.start in node_pairs:
new_start = node_pairs[slice_item.start]
elif slice_item.stop in node_pairs:
new_stop = node_pairs[slice_item.stop]
elif slice_item.step in node_pairs:
new_step = node_pairs[slice_item.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_slice_item = _update_slice_object_args(slice_item)
new_slice_items.append(new_slice_item)
new_args = (node.args[0], tuple(new_slice_items))
@ -255,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
return gm
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
This pass will process node args to adapt the distributed tensor layout.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec):
'''
This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec.
'''
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
for element in sharding_spec:
dim_partition_dict.append(_extract_info_from_sharding_spec(element))
return dim_partition_dict, sharding_spec
def _process_node_arguments(node):
new_args = []
for arg in node.args:
# There are two args style:
# 1. (input, *shape)
# 2. (input, shape)
# We will extract the elements from shape and add them into the new_args
# Finally, the args style of new_args will be unified to (input, *shape)
if isinstance(arg, Node):
if isinstance(arg._meta_data, (tuple, list)):
new_args.extend(arg._meta_data)
elif isinstance(arg._meta_data, int):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(arg,
(int, tuple, list)), 'The argument in the view node should be a Node, an int, a tuple, or a list.'
if isinstance(arg, (tuple, list)):
new_args.extend(arg)
else:
new_args.append(arg)
return new_args
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node)
if node.op == 'call_method':
args_to_process = list(new_args[1:])
else:
args_to_process = list(new_args)
for dim, shard_dims in dim_partition_dict.items():
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
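# e.g. a dim sharded over mesh dims [0, 1] with mesh shape (2, 2) gives total_shard_size = 4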
# we will skip the dim with -1 value
if args_to_process[dim] == -1:
continue
else:
# TODO: add assertion here to make sure the dim size is divisible by total_shard_size
args_to_process[dim] //= total_shard_size
args_to_process = tuple(args_to_process)
if node.op == 'call_method':
new_args = (new_args[0],) + args_to_process
else:
new_args = args_to_process
node.args = new_args
def _filter_node_with_shape_args(node):
if node.op == 'call_method':
target = getattr(node.args[0]._meta_data.__class__, node.target)
elif node.op == 'call_function':
target = node.target
else:
target = None
if target in SHAPE_ARGUMENT_OPS:
return True
return False
for node in nodes:
# skip the placeholder nodes added in the solution_annotatation_pass
if not hasattr(node, 'sharding_spec'):
continue
def _process_sharding_spec(sharding_spec):
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
for element in sharding_spec:
dim_partition_dict.append(_process_sharding_spec(element))
return dim_partition_dict, sharding_spec
output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
new_args = []
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
# process the node with (input, *shape) style args
if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (int, tuple, list)):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
# There are two ways to use torch.view:
# 1. torch.view(input, *shape)
# 2. torch.view(input, shape)
if isinstance(new_args[1], int):
# we will skip the dim with -1 value
if new_args[dim + 1] == -1:
continue
else:
new_args[dim + 1] //= total_shard_size
else:
new_args[1] = list(new_args[1])
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
else:
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
elif node.op == 'call_function':
target = node.target
# process the node with (input, torch.Size) style args
if target in (torch.reshape,):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (tuple, list)):
new_args.append(list(arg._meta_data))
else:
new_args.append(arg)
else:
assert isinstance(
arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
new_args.append(list(arg))
for dim, shard_dims in output_dim_partition_dict.items():
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
if _filter_node_with_shape_args(node):
_scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node)
return gm
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of solver solution.
@ -361,6 +386,50 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
nodes = tuple(mod_graph.nodes)
# This stream is created for overlapping the communication and computation.
reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param):
comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action):
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
return True
if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
return True
return False
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if _filter_param_to_hook(node, operation_data, comm_action):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
def _shard_param(param, target_sharding_spec):
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
return param
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
@ -370,35 +439,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
setattr(target_module, 'processed', True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
param.data, param.sharding_spec, target_sharding_spec).detach().clone()
param = _shard_param(param, target_sharding_spec)
setattr(target_module, name, param)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
_add_hook_for_grad_communication(node, param)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
@ -426,36 +470,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
setattr(target, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
target.data, target.sharding_spec, target_sharding_spec).detach().clone()
target = _shard_param(target, target_sharding_spec)
assert hasattr(target_module, atoms[-1])
setattr(target_module, atoms[-1], target)
_add_hook_for_grad_communication(node, target)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
return gm
@ -469,14 +489,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor = None,
strategies_constructor: StrategiesConstructor,
overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
gm, solution, strategies_constructor)
gm = _size_value_converting(gm, device_mesh)
gm = _node_args_converting(gm, device_mesh)
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
# gm = implicit_comm_action_apply(gm)
gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap)
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict

View File

@ -1,6 +0,0 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .solver import Solver
from .strategies_constructor import StrategiesConstructor

View File

@ -1,142 +0,0 @@
import functools
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union
import torch
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INFINITY_COST
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
elif isinstance(input_, torch.Tensor):
shape = input_.shape
else:
raise TypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
def generate_resharding_costs(nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
index=None):
'''
Compute the resharding costs with this specific strategy.
Argument:
nodes (List[Node]): a list of nodes
sharding_specs (List[ShardingSpec]): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
input_sharding_spec = input_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensors.'
try:
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# we need to multiply by the element size of the dtype to get the correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
warnings.warn(f'{e}')
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def ignore_sharding_exception(func):
"""
A decorator which catches the AssertionError raised inside the wrapped function and converts it into a warning, so that invalid sharding strategies are skipped instead of aborting the search.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
return wrapper
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
return dim_partition_list
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
return dim_partition_list
def generate_sharding_size(dim_partition_dict, device_mesh):
total_sharding_size = 1
for mesh_dim_list in dim_partition_dict.values():
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
total_sharding_size *= sharding_size
return total_sharding_size

View File

@ -1,83 +0,0 @@
import torch
import operator
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
# softmax should not be here
torch.nn.functional.softmax
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
# TODO: contiguous maybe need some extra processes.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_METHOD_OP = [
torch.Tensor.view,
torch.Tensor.unsqueeze,
torch.Tensor.split,
torch.Tensor.permute,
torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
torch.where,
operator.pow,
torch.pow,
torch.tanh,
torch.add,
torch.sub,
torch.mul,
torch.div,
torch.floor_divide,
torch.true_divide,
operator.add,
operator.sub,
operator.mul,
operator.floordiv,
operator.truediv,
# softmax should not be here
torch.nn.functional.softmax
]
INFINITY_COST = 1e13

View File

@ -1,174 +0,0 @@
import math
from typing import List
from torch.fx.node import Node
from .constants import INFINITY_COST
class CostGraph:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the search space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of the target node and its following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is True. (defaults to True)
'''
def __init__(self, leaf_strategies, simplify=True):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self.simplify = simplify
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
This method will generate edge_cost for each adjacent node pair. Additionally, 'parents' and 'children' attributes will be
set on each node.
'''
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
for strategies_vector in self.leaf_strategies:
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
'''
To merge dst_node into src_node, we need to do it in following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge. This is important because the logical resharding costs
between the parent nodes of src_node and the merged node depend on how the
src_node strategies are dispatched. For example, for the graph 0->1->2, after merging node 1
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)],
where x is the strategy of node 1 picked when merging into strategy 0 of node 2.
2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
another is the origin extra costs in src_node strategy.
3. Build connections between new node pairs, and remove the src_node after all consumer nodes
detached from it.
Argument:
src_node(Node): The node will be merged into dst_node.
dst_node(Node): The node to integrate src_node.
'''
src_node_index = dst_node.parents.index(src_node)
# build merge_map
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
if resharding_cost <= min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
# extra_node_cost for src node
self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
for src_index, strategy in enumerate(src_node.strategies_vector):
target_strate_index = merge_map[src_index]
target_strategy = dst_node.strategies_vector[target_strate_index]
self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
if dst_node in self.extra_node_costs:
self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]
# add new node pair to cost graph
for child_node in dst_node.children:
new_node_pair = (src_node, child_node)
old_node_pair = (dst_node, child_node)
if new_node_pair in self.edge_costs:
continue
edge_cost = {}
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strate_index = merge_map[i]
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost
else:
# we should accumulate the resharding costs if args of child node contain
# both src node and dst node.
for index_pair in self.edge_costs[new_node_pair]:
self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
# connect src node and children of dst node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for child_node in dst_node.children:
if child_node not in src_node.children:
src_node.children.append(child_node)
if src_node not in child_node.parents:
child_node.parents.append(src_node)
# remove dst node from cost graph when dst node has no producer.
if len(dst_node.parents) == 0:
child_node.parents.remove(dst_node)
node_pair = (dst_node, child_node)
self.edge_costs.pop(node_pair)
if len(dst_node.parents) == 0:
self.following_dict[dst_node] = src_node
dst_node.children = []
def _reindexing_src(self, src):
if src not in self.following_dict:
return src
return self._reindexing_src(self.following_dict[src])
def simplify_graph(self):
if not self.simplify:
return
self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
for dst, src in self.following_dict.items():
reindexing_following_dict[dst] = self._reindexing_src(src)
self.following_dict = reindexing_following_dict

View File

@ -1,165 +0,0 @@
from collections import OrderedDict as ODict
from dataclasses import dataclass
from typing import Any, List, OrderedDict, Union
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
@dataclass
class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
node: Node
is_inplace: bool
class LiveVariableVector(list):
"""
LiveVariableVector is a data structure to store the list of LiveVariable objects.
"""
def exists(self, name) -> bool:
"""
Check whether a variable already exists in the current list by name.
"""
for var in self:
if name == var.name:
return True
return False
def get(self, name) -> LiveVariable:
for var in self:
if name == var.name:
return var
raise KeyError(f"Variable {name} is not found")
def copy(self) -> "LiveVariableVector":
"""
Create a copy of this vector
"""
vector = LiveVariableVector()
for var in self:
vector.append(var)
return vector
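# Illustrative usage (not part of the original file; `node` stands for any fx Node):
#   vec = LiveVariableVector()
#   vec.append(LiveVariable(name='x', node=node, is_inplace=False))
#   vec.exists('x')      # -> True
#   vec.get('x').name    # -> 'x'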
@dataclass
class LiveStage:
"""
LiveStage is a data structure to record the live variables at the current node.
"""
name: str
node: Node
all_live_vars: LiveVariableVector
unique_live_vars: LiveVariableVector
class GraphAnalyser:
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@property
def gm(self) -> GraphModule:
"""
Return the GraphModule object associated with this analyser.
"""
return self._gm
@property
def graph(self) -> Graph:
"""
Return the Graph object associated with this analyser.
"""
return self._graph
def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
a list of LiveStage objects, one for each compute stage in execution order.
"""
compute_nodes = self.graph.nodes
liveness_list = []
# checked: record all variables created since the first stage
# all: record the live variables that exist up to the current stage.
# this can be different from the `checked` list, as some variables may be destroyed prior to this stage.
# unique: record the unique live variables that exist up to the current stage.
# this is different from the `all` list, as some variables are duplicated.
checked_variables = LiveVariableVector()
all_live_variables = LiveVariableVector()
unique_live_vars = LiveVariableVector()
for idx, node in enumerate(compute_nodes):
#############################
# find new living variables #
#############################
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
if node.op == 'call_function':
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False):
is_inplace = True
elif node.op == 'call_module':
# check if this is an inplace op such as torch.nn.ReLU(inplace=True)
module = get_node_module(node)
if getattr(module, 'inplace', False):
is_inplace = True
# add the output var
meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
checked_variables.append(live_var)
all_live_variables.append(live_var)
# check if any input is not checked yet
for arg in node.args:
if not isinstance(arg, Node):
continue
arg_name = arg.name
if not checked_variables.exists(arg_name):
live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
all_live_variables.append(live_var_from_arg)
checked_variables.append(live_var_from_arg)
unique_live_vars.append(live_var_from_arg)
# TODO: add the logic to remove live variables
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
stage = LiveStage(name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
all_covered = True
for ele in prev_stage.unique_live_vars:
if ele not in stage.unique_live_vars:
all_covered = False
break
if all_covered:
replace = True
break
if replace:
liveness_list[index] = stage
else:
liveness_list.append(stage)
return liveness_list
def get_alias_set(self):
pass
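# Illustrative usage sketch (not part of the original file; assumes a module
# that torch.fx can symbolically trace):
#
#   import torch
#   from torch.fx import symbolic_trace
#
#   class Net(torch.nn.Module):
#       def forward(self, x):
#           y = torch.relu(x)
#           return y + x
#
#   gm = symbolic_trace(Net())
#   analyser = GraphAnalyser(gm)
#   for stage in analyser.liveness_analysis():
#       print(stage.name, [v.name for v in stage.unique_live_vars])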

View File

@ -1,15 +0,0 @@
from .batch_norm_handler import BatchNormHandler
from .bcast_op_handler import BcastOpHandler
from .conv_handler import ConvHandler
from .dot_handler import DotHandler
from .embedding_handler import EmbeddingHandler
from .layer_norm_handler import LayerNormHandler
from .operator_handler import OperatorHandler
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler
__all__ = [
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler'
]

View File

@ -1,492 +0,0 @@
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['BatchNormHandler']
class BatchNormHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of normalization.
To keep the math consistent, there are two ways to do BatchNorm if the input
is sharded on the batch dimension:
1. We gather the input partitions along the batch dimension, then do the normal BatchNorm.
2. We do SyncBatchNorm on each input partition separately; the SyncBN op will help
us keep the computation correct.
In this handler, both methods will be considered.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.bias = self.module_named_parameters['bias']
self.output_data = self.node._meta_data
self._sanity_check()
def _sanity_check(self):
'''
In the sanity check, we need to make sure the input data has the correct number of dimensions.
For BatchNorm1d, the dim of input data should be 3 ([N, C, L]).
For BatchNorm2d, the dim of input data should be 4 ([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5 ([N, C, H, W, D]).
'''
assert self.input_data.dim() in (3, 4, 5), \
'We assume the dim of the input fed into a BatchNorm op should be in the range [3, 5].'
def _generate_compute_cost(self, bs, channel_in):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
Argument:
bs(int): Batch size of the input data.
channel_in(int): The channel dimension of input data.
Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost needs to be divided by TFLOPS, now it just shows the computation size.
# TODO: a constant coefficient needs to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
input_size = self.input_data.shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
forward_compute_cost = input_size_product * bs * channel_in
backward_activation_compute_cost = input_size_product * bs * channel_in
backward_weight_compute_cost = input_size_product * bs * channel_in
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
compute_cost = forward_compute_cost + backward_compute_cost
return compute_cost
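# Worked example (hypothetical shapes): for a BatchNorm1d input of shape
# [N=4, C=8, L=16] with bs=4 and channel_in=8:
#   forward  = 16 * 4 * 8 = 512
#   backward = 512 (activation) + 512 (weight) = 1024
#   compute_cost = 512 + 1024 = 1536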
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number of partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number of partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number of partitions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = numel_output
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
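# Worked example (hypothetical fp32 tensors): with numel_output = 512,
# numel_weight = 8, size_per_elem_bytes = 4 and all three sharding sizes = 2:
#   memory_cost_forward  = 512 * 4 / 2 + 8 * 4 / 2 = 1024.0 + 16.0 = 1040.0
#   memory_cost_backward = 512 * 4 / 2 + 8 * 4 / 2 = 1040.0
#   memory_cost = (1040.0, 1040.0)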
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy does not need an all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
# the memory cost needs to be recomputed
# compute the memory cost of the new strategy
new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# the communication cost needs to account for the sharding cost of this strategy
# compute the communication cost of the new strategy
origin_communication_cost = communication_cost
tiny_shard_cost = 10
new_forward_communication_cost = tiny_shard_cost
# we need to all-gather the batch dimension for the basic strategy
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1)
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
sharding_strategies = ShardingStrategy(new_name,
output_sharding_spec=new_sharding_spec_for_output,
compute_cost=new_compute_cost,
communication_cost=new_communication_cost,
memory_cost=new_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
self.device_mesh.shape[mesh_dim_1])
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy does not need an all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def non_split(self, mesh_dim_0, mesh_dim_1):
name = 'RR = RR x R'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy does not need an all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
dim_partition_dict_for_output = {0: mesh_dim_list}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
# the memory cost needs to be recomputed
new_sharding_size_input = 1
for mesh_dim in mesh_dim_list:
new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim]
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight)
# the communication cost needs to account for the sharding cost of this strategy
origin_communication_cost = communication_cost
tiny_shard_cost = 10
new_forward_communication_cost = tiny_shard_cost
if len(mesh_dim_list) == 1:
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation,
mesh_dim_list[0])
else:
new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost(
memory_cost_backward_activation, 0)
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
new_sharding_strategy = ShardingStrategy(new_name,
output_sharding_spec=new_sharding_spec_for_output,
compute_cost=new_compute_cost,
communication_cost=new_communication_cost,
memory_cost=new_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input,
sharding_spec_for_weight))
return new_sharding_strategy
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
# shard on mesh_dim_0
new_name = f'S{mesh_dim_0}R = RR x R'
mesh_dim_list = [mesh_dim_0]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
# shard on mesh_dim_1
new_name = f'S{mesh_dim_1}R = RR x R'
mesh_dim_list = [mesh_dim_1]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
# shard on mesh_dim_0, mesh_dim_1
new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R'
mesh_dim_list = [mesh_dim_0, mesh_dim_1]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all-reduce communication will happen during the SyncBN computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all-reduce communication will happen during the SyncBN computation.
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all-reduce communication will happen during the SyncBN computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate all possible strategies for a BatchNorm node, and record them in the strategies_vector.
Example:
norm_handler = BatchNormHandler(node, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
Output:
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
'''
# RS = RS x S and strategies based on it, such as
# SS = RS x S
self.split_input_channel(0, 1)
self.split_input_channel(1, 0)
# RR = RR x R and strategies based on it, such as
# SR = SR x R
self.non_split(0, 1)
# RS01 = RS01 x S01
self.split_input_channel_1d(0, 1)
# SR = SR x R WITH SYNC_BN
self.split_input_batch(0)
self.split_input_batch(1)
# SS = SS x S WITH SYNC_BN
self.split_input_both_dim(0, 1)
self.split_input_both_dim(1, 0)
# S01R = S01R x R WITH SYNC_BN
self.split_input_batch_1d(0, 1)
return self.strategies_vector

View File

@ -1,552 +0,0 @@
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['BcastOpHandler']
class BcastOpHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of broadcast operators (such as operator.add).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert len(self.predecessor_node) == 2
self.lhs_data = self.predecessor_node[0]._meta_data
self.rhs_data = self.predecessor_node[1]._meta_data
self.lhs = self.predecessor_node[0]
self.rhs = self.predecessor_node[1]
self.output_data = self.node._meta_data
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
shape = list(input_.shape)
# padding the shape to the same length as output_data
while len(shape) < self.output_data.dim():
shape.insert(0, 1)
shape = torch.Size(shape)
# if the sharding happens on a size one dimension, we should record it as R.
processed_dim_partition_dict = deepcopy(dim_partition_dict)
for dim_index, _ in dim_partition_dict.items():
if shape[dim_index] == 1:
processed_dim_partition_dict.pop(dim_index)
for dim_index, sharding_index_list in processed_dim_partition_dict.items():
sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=shape,
dim_partition_dict=processed_dim_partition_dict)
return sharding_spec
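# Example (hypothetical shapes): if the output is [4, 8, 8] and one input is
# [8, 8], the input shape is first padded to [1, 8, 8]; any sharding requested
# on the padded size-1 dimension is then dropped and recorded as R (replicated).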
def _generate_compute_cost(self, total_sharding_size):
lhs_matrix_shape = self.lhs_data.shape[-2:]
rhs_matrix_shape = self.rhs_data.shape[-2:]
batch_dimensions_shape = self.output_data.shape[:-2]
batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
compute_cost = reduce(
operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
return compute_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding cost of the weight is counted because of weight-sharing cases.
dtype = self.node._meta_data.dtype
nodes = self.predecessor_node
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
# if the input shape is smaller than the target input, we pad the input to the same length as the target,
# then use the padded input sharding spec to compute the resharding cost.
if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
new_entire_shape = list(input_sharding_spec.entire_shape)
while len(new_entire_shape) < len(input_spec.entire_shape):
new_entire_shape.insert(0, 1)
new_entire_shape = torch.Size(new_entire_shape)
new_device_mesh = input_sharding_spec.device_mesh
new_dim_partition_dict = input_sharding_spec.dim_partition_dict
input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
entire_shape=new_entire_shape,
dim_partition_dict=new_dim_partition_dict)
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# we need to multiply by the element size of the dtype to get the correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
sharding_spec_list = []
check_duplicated_list = []
for output_dim_partition_dict in dim_partition_list:
try:
output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
except AssertionError as e:
warnings.warn(f'{e}')
# skip the invalid partition dict and keep enumerating the rest
continue
sharding_seq = output_sharding_spec.sharding_sequence
if sharding_seq not in check_duplicated_list:
check_duplicated_list.append(sharding_seq)
sharding_spec_list.append(output_sharding_spec)
return sharding_spec_list
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.
output_dim_partition_list = []
dim_size = self.output_data.dim()
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_2d)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
# add empty dict for fully replicated case
output_dim_partition_list.append({})
output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
return output_sharding_spec_list
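# Example (2-dimensional output on a 2D mesh): the enumeration yields partition
# dicts such as
#   {0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]},   # 2D cases
#   {0: [0]}, {1: [0]}, {0: [1]}, {1: [1]},                         # 1D cases
#   {}                                                              # fully replicated
# which are then converted into ShardingSpec objects, with duplicates dropped.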
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)
name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the computation cost of this strategy
sharding_dims = []
for mesh_dims in dim_partition_dict_for_output.values():
for mesh_dim in mesh_dims:
sharding_dims.append(self.device_mesh.shape[mesh_dim])
sharding_size = reduce(operator.mul, sharding_dims, 1)
memory_cost = self.output_data.numel() / sharding_size
compute_cost = memory_cost
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=output_sharding_spec,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
##############################################
#used to generate strategies for torch.matmul#
##############################################
@ignore_sharding_exception
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
# this dim partition dict only describes the batch dimensions, but in this scenario,
# the matrix dimensions are fully replicated, so no extra processing is needed.
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
batch_sharding_dims = []
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
# in this case, total_sharding_size is equal to the batch sharding size
memory_cost = self.output_data.numel() / batch_sharding_size
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(batch_sharding_size)
# in this case, no communication takes place.
# TODO: add all-reduce cost if lhs or rhs is of type Parameter.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describes the batch dimensions, so we should append the matrix-dimension sharding info to it.
# In this scenario, the matrix dimensions will be sharded on the 'i' dimension.
# in this case, the matrix dimensions of lhs is sharded on 'i' dimension.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix})
# in this case, the matrix dimensions of rhs is fully replicated.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
# in this case, the matrix dimensions of output is sharded on 'i' dimension.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix})
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / total_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is of type Parameter.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describes the batch dimensions, so we should append the matrix-dimension sharding info to it.
# In this scenario, the matrix dimensions will be sharded on the 'k' dimension.
# in this case, the matrix dimensions of lhs is sharded on 'k' dimension.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix})
# in this case, the matrix dimensions of rhs is sharded on 'k' dimension.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix})
# in this case, the matrix dimensions of output is fully replicated.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
batch_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data is fully replicated on matrix dimensions.
memory_cost = self.output_data.numel() / batch_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is of type Parameter.
# The communication takes place during forward activation computation.
if len(mesh_dim_on_matrix) == 1:
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
else:
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describes the batch dimensions, so we should append the matrix-dimension sharding info to it.
# In this scenario, the matrix dimensions will be sharded on the 'j' dimension.
# in this case, the matrix dimensions of lhs is fully replicated.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
# in this case, the matrix dimensions of rhs is sharded on 'j' dimension.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix})
# in this case, the matrix dimensions of output is sharded on 'j' dimension.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix})
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / total_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is of type Parameter.
# The communication takes place during backward activation computation.
if len(mesh_dim_on_matrix) == 1:
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
else:
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list):
self._split_dim_i(dim_partition_dict, mesh_dim_list)
self._split_dim_k(dim_partition_dict, mesh_dim_list)
self._split_dim_j(dim_partition_dict, mesh_dim_list)
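# Summary of the three 1D splits for [b, i, k] x [b, k, j] -> [b, i, j]
# (batch-dimension sharding comes from dim_partition_dict):
#   split 'i': lhs sharded on dim -2, rhs replicated, output sharded on dim -2, no all-reduce
#   split 'k': lhs sharded on dim -1, rhs sharded on dim -2, output replicated, all-reduce in forward
#   split 'j': lhs replicated, rhs sharded on dim -1, output sharded on dim -1, all-reduce in backward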
@ignore_sharding_exception
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-2: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-2: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
# the output is only sharded on mesh_dim_0 here, so the per-device memory
# is divided by that mesh dimension rather than by the full output shape.
memory_cost = self.output_data.numel() / self.device_mesh.shape[mesh_dim_0]
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is of type Parameter.
# The communication takes place during forward activation computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
# the output is only sharded on mesh_dim_1 here, so the per-device memory
# is divided by that mesh dimension rather than by the full output shape.
memory_cost = self.output_data.numel() / self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is of type Parameter.
# The communication takes place during forward and backward activation computation.
communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
communication_cost = communication_cost_backward_activation + communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-1: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / total_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is of type Parameter.
# The communication takes place during backward activation computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _registry_2d_strategies_for_matmul(self):
self._split_lhs_space_both_contract(0, 1)
self._split_lhs_space_both_contract(1, 0)
self._split_rhs_space_both_contract(0, 1)
self._split_rhs_space_both_contract(1, 0)
self._split_lhs_space_rhs_space(0, 1)
self._split_lhs_space_rhs_space(1, 0)
def register_strategy(self) -> StrategiesVector:
MESH_DIM_LIST = [0, 1]
if self.node.target != torch.matmul:
output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
for output_sharding_spec in output_sharding_specs:
self._register_strategy(output_sharding_spec)
else:
# we only care about the non-computing dimensions,
# therefore, we omit the last two dimensions.
dim_size = self.output_data.dim() - 2
# Both device mesh axes are used on batch dimensions
dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
for dim_partition_dict in dim_partition_dicts_2d:
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
# Only one device mesh axis is used on batch dimensions
for mesh_dim_index in [0, 1]:
dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
for dim_partition_dict in dim_partition_dicts_1d:
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
# No device mesh axis is used on batch dimensions
dim_partition_dict_on_batch_dim = {}
self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim)
self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST)
self._registry_2d_strategies_for_matmul()

View File

@ -1,609 +0,0 @@
import operator
import warnings
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['ConvHandler']
class ConvHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of Convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data
self._sanity_check()
def _sanity_check(self):
'''
In the sanity check, we need to make sure the input data has the correct number of dimensions.
For Conv1d, the dim of input data should be 3 ([N, C, L]).
For Conv2d, the dim of input data should be 4 ([N, C, H, W]).
For Conv3d, the dim of input data should be 5 ([N, C, H, W, D]).
'''
assert self.input_data.dim() in (3, 4, 5), \
'We assume the dim of the input fed into a conv op should be in the range [3, 5].'
def _generate_compute_cost(self, bs, channel_in, channel_out):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
Argument:
bs(int): Batch size of the input data.
channel_in(int): The channel dimension of input data.
channel_out(int): The out channel of the conv weight.
Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost needs to be divided by TFLOPS, now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
output_size = self.output_data.shape[2:]
output_size_product = reduce(operator.mul, output_size, 1)
input_size = self.input_data.shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
kernel_size = self.weight.shape[2:]
kernel_size_product = reduce(operator.mul, kernel_size, 1)
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
return compute_cost
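# Worked example (hypothetical Conv2d): bs=2, channel_in=4, channel_out=8,
# kernel 3x3 (product 9), output 16x16 (product 256), input 18x18 (product 324):
#   forward             = 256 * 2 * 4 * 8 * 9 = 147456
#   backward_activation = 324 * 2 * 4 * 8 * 9 = 186624
#   backward_weight     = 256 * 2 * 4 * 8 * 9 = 147456
#   compute_cost        = 481536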
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number of partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number of partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number of partitions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = self.input_data.numel()
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy does not need an all_reduce operation during the forward phase
communication_cost_forward = 0
# compute the backward communication cost to all reduce the input activation grad
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
mesh_dim_1)
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# total communication cost
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = 1
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# This strategy does not need an all_reduce operation in the forward phase.
communication_cost_forward = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# compute the total cost
communication_cost = communication_cost_forward + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
# This strategy does not need an all_reduce operation to compute the input activation grad
communication_cost_backward_activation = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# compute total cost
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# This strategy does NOT need an all_reduce during the backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = 1
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy does not need an all_reduce during the forward phase
communication_cost_forward = 0
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def non_split(self):
name = 'RR = RR x RR'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy does not need an all_reduce in either the forward or backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
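# Note (added for clarity): non_split keeps every tensor replicated, so it pays no
# communication cost but the full per-device compute and memory cost; it acts as a
# baseline against which the sharded candidates can be compared.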
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_weight = 1
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# This strategy does not need an all_reduce in the forward phase
communication_cost_forward = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
# compute the total communication cost
communication_cost = communication_cost_backward_weight + communication_cost_forward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
self.device_mesh.shape[mesh_dim_1])
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute communication cost during forward phase
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_forward_activation, 0)
# This strategy does NOT need an all_reduce during the backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategy for a Conv node, and record all strategies in the strategies_vector.
Example:
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, conv, output]
nodes = [node for node in gm.graph.nodes]
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
conv_handler.register_strategy_into_strategies_vector()
for strategy in conv_handler.strategies_vector:
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
Output:
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]}
RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
'''
# SS = SR x RS
self.split_input_batch_weight_out_channel(0, 1)
self.split_input_batch_weight_out_channel(1, 0)
# SR = SR x RR
self.split_input_batch(0)
self.split_input_batch(1)
# SR = SS x SR
self.split_input_both_dim_weight_in_channel(0, 1)
self.split_input_both_dim_weight_in_channel(1, 0)
# RS = RS x SS
self.split_input_in_channel_weight_both_channel(0, 1)
self.split_input_in_channel_weight_both_channel(1, 0)
# RR = RS x SR
self.split_input_in_channel_weight_in_channel(0)
self.split_input_in_channel_weight_in_channel(1)
# RS = RR x RS
self.split_weight_out_channel(0)
self.split_weight_out_channel(1)
# RR = RR x RR
self.non_split()
# S01R = S01R x RR
self.split_1d_parallel_on_input_batch(0, 1)
# RR = RS01 x S01R
self.split_1d_parallel_on_in_channel(0, 1)
return self.strategies_vector
CONV_STRATEGIES_LIST = [
'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
]
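# Minimal usage sketch (hypothetical solver step, not part of the original file):
# once register_strategy() has populated the vector, one could rank candidates by
# their recorded costs, e.g.
# best = min(conv_handler.strategies_vector,
#            key=lambda s: s.compute_cost + s.communication_cost)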

View File

@ -1,756 +0,0 @@
import operator
from enum import Enum
from functools import reduce
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from .operator_handler import OperatorHandler
from .strategy_generator import IntermediateStrategy, StrategyGenerator
__all__ = ['DotHandler']
class DotProductStrategyGenerator(StrategyGenerator):
"""
DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation.
This is created for torch.matmul where both tensors are 1D. Since torch.matmul does not include a bias argument, we
do not consider bias here.
"""
def validate(self, input, other):
assert input.dim() == 1 and other.dim() == 1
def no_split(self):
name = 'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_one_dim(self, mesh_dim):
name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# do not split dimensions for dot product
# R = R dot R
strategy_list.append(self.no_split())
# split two tensors in the same dimensions
# S = S dot S
strategy_list.append(self.split_one_dim(0))
strategy_list.append(self.split_one_dim(1))
return strategy_list
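# Worked example (illustrative only): for two length-8 vectors split across 2
# devices along one mesh dimension, each device computes a partial dot product over
# its 4 local elements; the all_reduce over that axis (declared via all_reduce_axis
# above) then sums the partial results into the full scalar.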
class MatVecStrategyGenerator(StrategyGenerator):
def validate(self, input, other) -> bool:
assert input.dim() > 1 and other.dim() == 1
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# no split
strategy_list.append(self.no_split())
# split the batch dim for the first tensor only
strategy_list.append(self.split_input_batch(0))
strategy_list.append(self.split_input_batch(1))
return strategy_list
class MatMulStrategyGenerator(StrategyGenerator):
"""
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
A matmul can be formulated as [n, p] x [p, q] = [n, q]
Args:
is_linear (bool): whether this generator is used for nn.Linear and F.linear.
This will incur extra transformation of the dim partitioning as the weight is transposed.
"""
def __init__(self, is_linear: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_linear = is_linear
# as the weight for the linear module is transposed, we can compute
# the corresponding dimension index for convenience
if is_linear:
self.dim_q = 0
self.dim_p = 1
else:
self.dim_q = 1
self.dim_p = 0
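# Concrete layout example: nn.Linear stores its weight as
# (out_features, in_features) = (q, p), so for the linear case q is weight dim 0
# and p is dim 1; for a plain torch.matmul operand of shape (p, q), q is dim 1.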
def validate(self, input, other, bias) -> bool:
# make sure the second tensor is a 2D tensor
assert input.dim() > 0 and other.dim() == 2
# make sure bias is of the same dimension
if self.is_linear:
assert bias is None or bias.shape[-1] == other.shape[0]
else:
assert bias is None or bias.shape[-1] == other.shape[1]
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
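# the contracted dimension p is sharded along mesh_dim_1, so each device produces a
# partial sum of the output; this is why all_reduce_axis=[mesh_dim_1] is declared below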
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict = {
"input": {
-1: [mesh_dim_0]
},
"other": {
self.dim_p: [mesh_dim_0],
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict = {
"input": {
-1: [mesh_dim]
},
"other": {
self.dim_p: [mesh_dim]
},
"bias": {},
"output": {},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict = {
"input": {},
"other": {
self.dim_q: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {},
}
return IntermediateStrategy(name=name,
dim_partition_dict=dim_partition_dict,
all_reduce_axis=[mesh_dim_0, mesh_dim_1])
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {},
"other": {
self.dim_q: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
class BatchedMatMulStrategyGenerator(StrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
"""
def __init__(self, is_torch_bmm: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_torch_bmm = is_torch_bmm
def validate(self, input, other, bias) -> bool:
if self.is_torch_bmm:
# batched matmul computes [b, i, k] x [b, k, j] -> [b, i, j]
assert input.dim() > 2
assert input.shape[:-2] == other.shape[:-2]
assert input.shape[-1] == other.shape[-2]
assert bias is None or other.shape[-1] == bias.shape[0]
else:
# TODO: validate these inputs are broadcastable
pass
def split_one_batch_dim(self):
if 1 in self.device_mesh.mesh_shape:
mesh_dim = self.device_mesh.mesh_shape.index(1)
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"other": {
0: [mesh_dim]
},
"bias": {},
"output": {
0: [mesh_dim]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
else:
return None
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0]
},
"bias": {},
"output": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# split only the batch dimension
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
strategy = self.split_one_batch_dim()
if strategy:
strategy_list.append(strategy)
# split batch dim of two inputs and the i dim of the first tensor
# SbSi = SbSi x Sb
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
# split batch dim of two inputs and the j of the second tensor
# SbSj = Sb x SbSj
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
# split batch dim of two inputs and the k dim of two inputs
# Sb = SbSk x SbSk, need to all-reduce by k dim
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
return strategy_list
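# Worked example (illustrative only): on a (2, 2) mesh, split_batch_dim_both_contract(0, 1)
# shards [b, i, k] x [b, k, j] so each device holds b/2 batches and k/2 of the contracted
# dimension; the local bmm yields partial [b/2, i, j] results that are summed with an
# all-reduce along mesh dimension 1, matching all_reduce_axis=[mesh_dim_1] above.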
class DotHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data
def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
# TODO: consider bias addition
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
return compute_cost
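# Numeric sketch (illustrative only): for an input of shape (4, 16), a transposed
# linear weight of shape (32, 16) and total_sharding_size = 4, this gives
# (4 * 16) * 32 * 2 // 4 = 1024 multiply-accumulate units per device.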
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# linear layer weight is transposed during init
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute computation cost
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
communication_cost = communication_cost_activation_backward + communication_cost_weight_backward
# create and register strategy
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# since weight of the linear layer is transposed
# the actual dim to be sharded is 1
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict_for_input = {1: [mesh_dim]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
communication_cost = communication_cost_weight_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
activation_memory_cost, 0)
communication_cost = communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
input_grad_memory_cost, 0)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=total_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategy for a linear node, and record all strategies in the strategies_vector.
'''
# SS = SR x RS
self.split_lhs_space_rhs_space(0, 1)
self.split_lhs_space_rhs_space(1, 0)
# SR = SS x SR
self.split_lhs_space_both_contract(0, 1)
self.split_lhs_space_both_contract(1, 0)
# RS = RS x SS
self.split_rhs_space_both_contract(0, 1)
self.split_rhs_space_both_contract(1, 0)
# RR = RS x SR
self.recompute_split_both_contract(0)
self.recompute_split_both_contract(1)
# RS = RR x RS
self.split_rhs_space_only(0)
self.split_rhs_space_only(1)
# S01R = S01R x RR
self.split_lhs_1st_dim_1d(0, 1)
# RR = RS01 x S01R
self.split_lhs_2nd_dim_1d(0, 1)
# RS01 = RR x RS01
self.split_rhs_2nd_dim_1d(0, 1)
return self.strategies_vector

View File

@ -1,179 +0,0 @@
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['EmbeddingHandler']
class EmbeddingHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of Embedding operators (such as nn.Embedding).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data
def _generate_compute_cost(self, total_sharding_size):
input_shape = self.input_data.shape
weight_shape = self.weight.shape
input_shape_product = reduce(operator.mul, input_shape, 1)
weight_shape_product = reduce(operator.mul, weight_shape, 1)
compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
return compute_cost
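# Numeric sketch (illustrative only): for indices of shape (4, 8) and an embedding
# table of shape (1000, 64) on 4 devices, this model gives 32 * 64000 * 2 / 4 =
# 1024000.0 units per device; note this deprecated cost model treats the lookup as a
# dense matmul rather than a gather.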
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number of partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number of partitions.
sharding_size_weight(int): The weight gradient will be divided
into sharding_size_weight number of partitions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy; the first element of this tuple is the forward
memory cost, and the second element is the backward memory cost.
memory_cost_forward_activation(float): Memory cost of the forward activation
per device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of the backward activation
per device with this specific strategy.
memory_cost_backward_weight(float): Memory cost of the weight gradient
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = self.input_data.numel()
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@ignore_sharding_exception
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {2: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = 1
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy does not need an all-reduce during the forward phase
communication_cost_forward = 0
# compute the communication cost of this strategy during backward phase
communication_cost_backward_activation = 0
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate all possible strategies for this node and record them in the strategies_vector.
'''
# RRS = RR x SS
self.split_weight_both_dim(0, 1)
self.split_weight_both_dim(1, 0)
# SSR = SS x RR
self.split_input_both_dim(0, 1)
self.split_input_both_dim(1, 0)
return self.strategies_vector
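# Usage sketch (illustrative only; assumes `node` and `device_mesh` were built by
# the surrounding auto-parallel machinery, and `handler` is an instance of this
# handler class):
#
#     handler.register_strategy()
#     for strategy in handler.strategies_vector:
#         print(strategy.name, strategy.compute_cost, strategy.memory_cost)
#
# The paired (0, 1) and (1, 0) calls above enumerate both assignments of the two
# mesh dimensions, so each split pattern is tried with the mesh axes swapped.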

View File

@ -1,241 +0,0 @@
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size,
ignore_sharding_exception,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['LayerNormHandler']
class LayerNormHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies of normalization.
Note: To keep the math consistency, LayerNorm do not allow shards on hidden dimension.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.bias = self.module_named_parameters['bias']
self.output_data = self.node._meta_data
def _generate_compute_cost(self, total_sharding_size):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost needs to be divided by TFLOPS; for now it just reflects the computation size.
Argument:
total_sharding_size(int): The total number of devices the computation is sharded over.
Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost needs to be divided by TFLOPS; for now it just reflects the computation size.
# TODO: a constant coefficient needs to be added.
norm_kernel_size = self.weight.shape
# in the LayerNorm context, batch dimensions are all the dimensions that do not take part in the normalization.
input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
# Computing the gradient of one norm kernel element requires input_batch_product operations, so
# the total cost is input_batch_product * norm_kernel_product
backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
compute_cost = forward_compute_cost + backward_compute_cost
return compute_cost
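# Worked example (hypothetical numbers): for an input of shape (4, 128, 768)
# normalized over the last dimension, input_batch_product = 4 * 128 = 512 and
# norm_kernel_product = 768. On a 2 x 2 mesh (total_sharding_size = 4), the
# forward cost is 512 * 768 / 4 = 98304 and the backward cost is twice that,
# giving compute_cost = 3 * 98304 = 294912.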
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number of partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number of partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number of partitions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy; the first element of this tuple is the forward
memory cost, and the second element is the backward
memory cost.
memory_cost_forward_activation(float): Memory cost of the forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of the backward activation
per device with this specific strategy.
memory_cost_backward_weight(float): Memory cost of the weight gradient
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
# this operation will not change the shape of input
numel_input = numel_output
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_for_input = dim_partition
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = dim_partition
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
total_mesh_dim_list = []
for mesh_dim_list in dim_partition.values():
total_mesh_dim_list.extend(mesh_dim_list)
# This strategy does not need any all-reduce operation for the activation
communication_cost_forward_activation = 0
communication_cost_backward_activation = 0
if len(total_mesh_dim_list) == 1:
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
total_mesh_dim_list[0])
else:
assert len(total_mesh_dim_list) == 2, 'temporarily we only support a 2D device mesh.'
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_batch_single_mesh_dim(self, mesh_dim_0):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@ignore_sharding_exception
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@ignore_sharding_exception
def non_split(self):
name = 'RR = RR x R'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
total_sharding_size = 1
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy does not need any all-reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate all possible strategies for a LayerNorm node, and record them in the strategies_vector.
Example:
norm_handler = BatchNormHandler(node, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
Output:
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
'''
# SR = SR x R with single mesh dim on batch dimensions
self.split_input_batch_single_mesh_dim(0)
self.split_input_batch_single_mesh_dim(1)
# SR = SR x R with both mesh dims on batch dimensions
self.split_input_batch_both_mesh_dim(0, 1)
# RR = RR x R
self.non_split()
return self.strategies_vector
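# Driver sketch (illustrative; assumes `node` and `device_mesh` exist in the
# calling context):
#
#     handler = LayerNormHandler(node, device_mesh, StrategiesVector(node))
#     for s in handler.register_strategy():
#         print(s.name)
#
# On a 2D mesh this yields batch-sharded strategies plus the fully replicated
# 'RR = RR x R'; the hidden dimension always stays replicated.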

View File

@ -1,149 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, List
import torch
import torch.nn as nn
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from .._utils import generate_resharding_costs, generate_sharding_spec
from ..sharding_strategy import StrategiesVector
__all__ = ['OperatorHandler']
class OperatorHandler(ABC):
'''
The OperatorHandler is an abstract class used to generate all possible strategies for an operator node.
Args:
node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
'''
def __init__(self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
handle_backward: bool = True):
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.handle_backward = handle_backward
# find the module and its parameters associated with this node
# this can be used to compute the compute/communication/sharding cost
if self.node.op == 'call_module':
module = node.graph.owning_module.get_submodule(node.target)
named_parameters = list(module.named_parameters(recurse=False))
# convert named parameters from list to dict
named_parameters = {k: v for k, v in named_parameters}
elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
module = None
parameters = list(self.node.args)[1]
if isinstance(parameters, Node):
named_parameters = {'weight': parameters._meta_data}
else:
named_parameters = {}
else:
module = None
named_parameters = None
self.module = module
self.module_named_parameters = named_parameters
@abstractmethod
def register_strategy(self) -> StrategiesVector:
"""
Register
"""
pass
def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
sharding_spec_for_input):
'''
Compute the memory cost per device with this specific strategy.
Argument:
dim_partition_dict_for_output(Dict[int, List[int]]): The key is the dimension of the output to be sharded,
and the value describes which logical mesh axes shard that dimension.
dim_partition_dict_for_weight(Dict[int, List[int]]): The key is the dimension of the weight to be sharded,
and the value describes which logical mesh axes shard that dimension.
sharding_spec_for_input(Dict[int, List[int]]): The key is the dimension of the input to be sharded,
and the value describes which logical mesh axes shard that dimension.
Return:
memory_cost(Tuple[float]): forward and backward memory cost per device with this specific strategy
activation_memory_cost(float): the memory cost of the activation per device with this specific strategy
weight_memory_cost(float): the memory cost of the weight per device with this specific strategy
input_grad_memory_cost(float): the memory cost of the input gradient per device with this specific strategy
'''
# compute the size of one element with specific dtype
dtype = self.input_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# compute the memory cost of activation
activation_numel = self.output_data.numel()
output_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
output_mesh_dims.extend(mesh_dims)
activation_sharding_size = 1
for mesh_dim in output_mesh_dims:
activation_sharding_size *= self.device_mesh.shape[mesh_dim]
activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes
# compute the memory cost of weight
weight_numel = self.weight.numel()
weight_sharding_size = 1
weight_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
weight_mesh_dims.extend(mesh_dims)
for mesh_dim in weight_mesh_dims:
weight_sharding_size *= self.device_mesh.shape[mesh_dim]
weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
# compute the memory cost of input grad
input_grad_numel = self.input_data.numel()
input_grad_sharding_size = 1
input_grad_mesh_dims = []
for sharding_dim, mesh_dims in sharding_spec_for_input.items():
input_grad_mesh_dims.extend(mesh_dims)
for mesh_dim in input_grad_mesh_dims:
input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes
memory_cost_forward = activation_memory_cost + weight_memory_cost
memory_cost_backward = input_grad_memory_cost + weight_memory_cost
return (memory_cost_forward,
memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding cost of the weight is counted to handle weight-sharing cases.
if hasattr(self.node._meta_data, 'dtype'):
dtype = self.node._meta_data.dtype
else:
assert isinstance(self.node._meta_data,
tuple), 'Only torch.Tensor, torch.fx.Node and tuples of torch.Tensor are expected'
dtype = self.node._meta_data[0].dtype
nodes = self.predecessor_node
return generate_resharding_costs(nodes=nodes,
sharding_specs=sharding_specs,
count_backward=self.handle_backward,
dtype=dtype)
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
return generate_sharding_spec(input_=input_,
device_mesh=self.device_mesh,
dim_partition_dict=dim_partition_dict)
@abstractmethod
def _generate_compute_cost(self, *args, **kwargs):
"""
Compute the flops involved in the node.
"""
pass
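# Subclassing sketch (hypothetical class, for illustration only): a concrete
# handler implements the two abstract hooks and appends ShardingStrategy objects
# (from ..sharding_strategy) to self.strategies_vector, e.g.:
#
#     class IdentityHandler(OperatorHandler):
#
#         def _generate_compute_cost(self, total_sharding_size):
#             return self.node._meta_data.numel() / total_sharding_size
#
#         def register_strategy(self) -> StrategiesVector:
#             spec = self._generate_sharding_spec(self.node._meta_data, {})
#             strategy = ShardingStrategy('R = R',
#                                         output_sharding_spec=spec,
#                                         compute_cost=self._generate_compute_cost(1),
#                                         resharding_costs=self._generate_resharding_costs([spec]))
#             self.strategies_vector.append(strategy)
#             return self.strategies_vector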

View File

@ -1,89 +0,0 @@
import colorsys
import math
import warnings
from copy import deepcopy
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST
from .operator_handler import OperatorHandler
class ReshapeHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of reshape operators, such as torch.reshape, torch.flatten, etc.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.output_data = self.node._meta_data
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@ignore_sharding_exception
def register_strategy(self):
# TODO: add strategies with more output sharding specs other than only fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0]
# For the reshape function, to keep the computation correct we require the
# sharding spec of the input to be fully replicated. In addition, we keep the
# output replicated and let the successor node choose how to reshard the
# output. Therefore, different strategies of the input node with the same
# output sharding spec generate the same strategy for the reshape function.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It may look a little confusing: the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict_for_output = {}
if isinstance(self.output_data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
try:
if isinstance(self.output_data, tuple):
output_sharding_spec = []
for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
else:
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
except AssertionError as e:
warnings.warn(f'{e}')
continue
name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = 0
# consider node._meta_data is in type of tuple
memory_cost = 0
# compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating.
dim_partition_dict_for_replicate_input = {}
replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
dim_partition_dict_for_replicate_input)
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
_, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
replicate_input_sharding_spec)
communication_cost = communication_cost["total"]
# generate resharding cost
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
# to prevent resharding from happening, set their resharding cost to inf.
resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
self.strategies_vector.append(sharding_strategy)

View File

@ -1,46 +0,0 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List
from colossalai.device.device_mesh import DeviceMesh
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
@dataclass
class IntermediateStrategy:
"""
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
Args:
name (str): name of the sharding strategy.
dim_partition_dict (Dict[str, Dict[int, List[int]]]): stores the operand-name to dim partition dict mapping.
all_reduce_axis (List[int]): stores the mesh axes which require an all-reduce operation.
"""
name: str
dim_partition_dict: Dict[str, Dict[int, List[int]]]
all_reduce_axis: List[int] = None
class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate a group of sharding strategies of the same kind.
"""
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
@abstractmethod
def generate(self) -> List[IntermediateStrategy]:
"""
"""
pass
@abstractmethod
def validate(self, *args, **kwargs) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
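# Sketch of a concrete generator (hypothetical, assuming the operand allows a
# 1D split of its batch dimension along mesh axis 0):
#
#     class BatchSplitGenerator(StrategyGenerator):
#
#         def validate(self, input_data) -> bool:
#             return input_data.dim() >= 1
#
#         def generate(self) -> List[IntermediateStrategy]:
#             return [IntermediateStrategy(name='S0R', dim_partition_dict={'input': {0: [0]}})]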

View File

@ -1,88 +0,0 @@
import math
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['UnaryElementwiseHandler']
class UnaryElementwiseHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.node.op == 'call_module':
target = self.node.target
submod = self.node.graph.owning_module.get_submodule(target)
submod_type = type(submod)
if submod_type == torch.nn.Dropout:
print(f'predecessor nodes of dropout node are {self.predecessor_node}')
input_nodes_len = 0
for check_node in self.predecessor_node:
if isinstance(check_node._meta_data, torch.Tensor):
input_nodes_len += 1
assert input_nodes_len == 1, f'Temporarily, we only support single-input element-wise ops; node name is {self.node}, node args are {self.node.args}.'
self.input_data = self.predecessor_node[0]._meta_data
self.input_node = self.predecessor_node[0]
self.output_data = self.node._meta_data
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@ignore_sharding_exception
def register_strategy(self):
# TODO: integrate element-wise functions and modules together
# create sharding strategies for an element-wise function
# For an element-wise function, we keep the sharding spec of the output node the same as
# the input. Therefore, different strategies of the input node with the same
# output sharding spec generate the same strategy for the element-wise function.
for index, strategy in enumerate(self.input_node.strategies_vector):
# It may look a little confusing: the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
try:
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
except AssertionError as e:
warnings.warn(f'{e}')
continue
# add the index into the name to pass the duplication check
# we keep identical strategies under different names for node merging; this does not increase the search space,
# because in the solver this node is merged into other nodes and the solver does not create a new variable for it.
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = self.output_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
# to prevent resharding from happening, set their resharding cost to inf.
resharding_costs[self.input_node] = [
0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
self.strategies_vector.append(sharding_strategy)

View File

@ -1,188 +0,0 @@
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['WhereHandler']
class WhereHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of torch.where.
"""
def __init__(self, *args, **kwargs):
# TODO: x or y could be scalar
super().__init__(*args, **kwargs)
assert len(self.predecessor_node) == 3
self.condition_data = self.predecessor_node[0]._meta_data
self.x_data = self.predecessor_node[1]._meta_data
self.y_data = self.predecessor_node[2]._meta_data
self.condition = self.predecessor_node[0]
self.x = self.predecessor_node[1]
self.y = self.predecessor_node[2]
self.output_data = self.node._meta_data
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
shape = list(input_.shape)
# padding the shape to the same length as output_data
while len(shape) < self.output_data.dim():
shape.insert(0, 1)
shape = torch.Size(shape)
# if the sharding happens on a size one dimension, we should record it as R.
processed_dim_partition_dict = deepcopy(dim_partition_dict)
for dim_index, _ in dim_partition_dict.items():
if shape[dim_index] == 1:
processed_dim_partition_dict.pop(dim_index)
for dim_index, sharding_index_list in processed_dim_partition_dict.items():
sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=shape,
dim_partition_dict=processed_dim_partition_dict)
return sharding_spec
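# Padding example (illustrative): with output_data of shape (B, S) and a
# condition of shape (S,), the condition shape is first padded to (1, S); a
# requested shard on dim 0 is then dropped because that dimension has size one,
# so the condition is recorded as replicated ('R') on the batch dimension.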
def _generate_compute_cost(self, total_sharding_size):
# torch.where is element-wise, so the per-device computation cost is the
# number of output elements divided by the total sharding size.
compute_cost = self.output_data.numel() / total_sharding_size
return compute_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
dtype = self.node._meta_data.dtype
nodes = self.predecessor_node
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# if the input shape is shorter than the target input, pad it to the same length as the target.
# Then use the padded input sharding spec to compute the resharding cost.
if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
new_entire_shape = list(input_sharding_spec.entire_shape)
while len(new_entire_shape) < len(input_spec.entire_shape):
new_entire_shape.insert(0, 1)
new_entire_shape = torch.Size(new_entire_shape)
new_device_mesh = input_sharding_spec.device_mesh
new_dim_partition_dict = input_sharding_spec.dim_partition_dict
input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
entire_shape=new_entire_shape,
dim_partition_dict=new_dim_partition_dict)
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
total_resharding_cost = total_resharding_cost['total']
# multiply by the element size of the dtype to get the correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
sharding_spec_list = []
check_duplicated_list = []
for output_dim_partition_dict in dim_partition_list:
try:
output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
except AssertionError as e:
warnings.warn(f'{e}')
break
sharding_seq = output_sharding_spec.sharding_sequence
if sharding_seq not in check_duplicated_list:
check_duplicated_list.append(sharding_seq)
sharding_spec_list.append(output_sharding_spec)
return sharding_spec_list
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# use mesh_dim_0 and mesh_dim_1 instead of the constants 0 and 1 here for N-D device mesh scalability.
output_dim_partition_list = []
dim_size = self.output_data.dim()
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_2d)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
# add empty dict for fully replicated case
output_dim_partition_list.append({})
output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
return output_sharding_spec_list
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input)
sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input)
name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs(
[sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y])
# compute the computation cost of this strategy
sharding_dims = []
for mesh_dims in dim_partition_dict_for_output.values():
for mesh_dim in mesh_dims:
sharding_dims.append(self.device_mesh.shape[mesh_dim])
sharding_size = reduce(operator.mul, sharding_dims, 1)
memory_cost = self.output_data.numel() / sharding_size
compute_cost = memory_cost
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=output_sharding_spec,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_condition, sharding_spec_for_x,
sharding_spec_for_y))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
MESH_DIM_LIST = [0, 1]
output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
for output_sharding_spec in output_sharding_specs:
self._register_strategy(output_sharding_spec)

View File

@ -1,11 +0,0 @@
from dataclasses import dataclass
__all__ = ['SolverOptions']
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
fast: bool = False
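# Usage sketch: with fast=True, the strategies constructor keeps placeholder and
# get_attr nodes fully replicated instead of enumerating all sharding specs.
#
#     options = SolverOptions(fast=True)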

View File

@ -1,91 +0,0 @@
from copy import deepcopy
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
import torch
from functools import reduce
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *
__all__ = ['ShardingStrategy', 'StrategiesVector']
@dataclass
class ShardingStrategy:
'''
ShardingStrategy is a structure containing the sharding strategies of the inputs and output of this node
and the cost information used by the solver.
Argument:
name(str): expresses the sharding strategy as a string, such as 'S0S1 = S0R x RS1'.
output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
compute_cost(float): Computation cost to complete this strategy. (default to 0)
communication_cost(float): Communication cost to complete this strategy. (default to 0)
memory_cost(float): Memory cost of the output node using this strategy. (default to 0)
resharding_costs(Dict[int, List[float]]): resharding_costs[i][j] means the cost of transforming the i-th argument in the output node argument list
with the j-th strategy in its strategies_vector to the sharding spec wanted in this
strategy. (default to None)
input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
'''
name: str
# TODO: the output of an fx node, such as torch.var_mean, could be a tuple, so we cannot simply assume it is a tensor.
output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
compute_cost: float = 0.
communication_cost: float = 0.
memory_cost: float = 0.
resharding_costs: Dict[Node, List[float]] = None
# sometimes the input node could be a tuple of nodes, but most ops won't accept a tuple of nodes as input.
# Therefore, we handle them in the specific op (operator.getitem)
input_shardings: List[ShardingSpec] = None
class StrategiesVector(list):
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
strategies of the node.
Argument:
node (Node): node for which the list of sharding strategies is generated.
'''
def __init__(self, node: Node):
super().__init__()
self.node = node
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
if self.node.op == 'output':
self.predecessor_nodes = list(node._input_nodes.keys())[:1]
self.successor_nodes = list(node.users.keys())
def check_merge(self):
merge_label = False
if self.node.op == 'call_module':
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
submod_type = type(submod)
# merge element-wise module nodes into their source nodes
# we can merge an element-wise op because its output sharding spec is always the same as its input sharding spec.
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
if self.node.op == 'call_function':
# we can merge an element-wise op because its output sharding spec is always the same as its input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
# we can merge a bcast op if the rhs is a scalar, because it then falls back to the element-wise case.
if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
merge_label = True
# we can merge a reshape op, because the output sharding spec of a reshape op is always fully replicated.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
return merge_label
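# Example (illustrative): for a call_function node produced by
# `y = torch.nn.functional.relu(x)`, relu falls under ELEMENTWISE_FUNC_OP
# (assuming it is registered there), so check_merge() returns True and the
# solver folds this node into its producer instead of giving it its own
# decision variable.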

View File

@ -1,469 +0,0 @@
import multiprocessing
import time
import warnings
from typing import Dict
import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node
from .constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor
try:
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except ImportError:
warnings.warn('please install pulp, e.g. via `pip install pulp`')
__all__ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3):
'''
The Solver class integrates the information provided by the other components and uses an ILP solver to find an optimal strategy combination for the target computation graph.
Argument:
graph: The computation graph to be optimized.
strategies_constructor: It provides all the possible strategies for each node in the computation graph.
cost_graph: A graph data structure used to simplify the edge cost graph.
graph_analyser: graph_analyser analyses the graph to obtain the variable liveness information, which is used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver returns a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, this coefficient is used to generate new memory budgets.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
def _recover_merged_node_strategy(self):
'''
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOps, were merged into the previous node.
Therefore, the strategy index of those nodes is copied from the previous node. This method recovers the strategy index of those merged
nodes.
'''
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_costs.append(strategy.compute_cost)
# a node in extra_node_costs carries some extra communication
# cost from node merging, so we need to add that extra communication
# cost in
if node in extra_node_costs:
origin_communication_cost = strategy.communication_cost
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(strategy.communication_cost)
# temporarily we only consider the forward memory cost
memory_cost = strategy.memory_cost
if isinstance(memory_cost, tuple):
memory_costs.append(memory_cost[0])
else:
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flattened numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set an initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost (compute cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost (resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node chooses only one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
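# In effect, constraints (e)-(g) force e[idx] to be the vectorized outer
# product of s[i] and s[j]: exactly one entry of e[idx] is 1, and it can only
# sit at the row/column selected by s[i] and s[j], so lpDot(e[idx], r[idx])
# picks out exactly the resharding cost of the chosen strategy pair.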
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
verbose = True
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle Python errors. Additionally,
we can produce a series of solutions with different memory budgets.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
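# End-to-end sketch (illustrative; assumes a traced fx `graph`, its owning
# GraphModule `gm`, and a `device_mesh` were already built with the APIs used in
# this package; exact constructor signatures may differ):
#
#     strategies_constructor = StrategiesConstructor(graph, device_mesh, SolverOptions(fast=True))
#     strategies_constructor.build_strategies_and_cost()
#     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
#     graph_analyser = GraphAnalyser(gm)
#     solver = Solver(graph, strategies_constructor, cost_graph, graph_analyser)
#     s_val, e_val, objective, status = solver.call_solver_serialized_args()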

View File

@ -1,426 +0,0 @@
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ._utils import generate_resharding_costs, generate_sharding_spec
from .constants import *
from .op_handler import *
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
__all__ = ['StrategiesConstructor']
class StrategiesConstructor:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with an owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
def remove_duplicated_strategy(self, strategies_vector):
'''
In the build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we remove the duplicated strategies based on the strategy name.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def _is_bcast_matmul(self, node):
is_bcast_matmul = False
if node.target is torch.matmul and len(node.args) == 2:
lhs_data = node.args[0]._meta_data
rhs_data = node.args[1]._meta_data
if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
is_bcast_matmul = True
return is_bcast_matmul
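# Example (illustrative): torch.matmul on tensors of shape (B, H, S, K) and
# (B, H, K, P) has lhs.dim() == rhs.dim() == 4 >= 3, so the node is flagged as
# a broadcast matmul and handled separately from the plain dot product case.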
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
input_nodes_len = 0
for check_node in strategies_vector.predecessor_nodes:
if isinstance(check_node._meta_data, torch.Tensor):
input_nodes_len += 1
# input_nodes_len = len(strategies_vector.predecessor_nodes)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options.fast is True, we just put them in
# fully replicated status; the strategies of following nodes are then treated equally because
# the replicated status has no resharding cost to any other status. At the same time, the search
# space is smaller than enumerating all the possible sharding specs for the placeholder node.
# Otherwise, all the possible sharding specs for the placeholder node are enumerated.
if self.solver_options.fast:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_placeholder = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_placeholder)
# get_attr node
if node.op == 'get_attr':
# Same as placeholder nodes: if solver_options.fast is True, we just put them in
# fully replicated status; the strategies of following nodes are then treated equally because
# the replicated status has no resharding cost to any other status. At the same time, the search
# space is smaller than enumerating all the possible sharding specs for the get_attr node.
# Otherwise, all the possible sharding specs for the get_attr node are enumerated.
if self.solver_options.fast:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_attribute)
# call_module node
if node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
# conv module
if submod_type in CONV_MODULE_OP:
# use ConvHandler to create sharding strategies for conv module node
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear module
elif submod_type in LINEAR_MODULE_OP:
# use DotHandler to create sharding strategies for linear module node
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
dot_handler.register_strategy()
# element-wise module
elif submod_type in ELEMENTWISE_MODULE_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# BatchNormNd module
elif submod_type in BATCHNORM_MODULE_OP:
# create sharding strategy for element-wise module
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
norm_handler.register_strategy()
# MaxPool module
elif submod_type in POOL_MODULE_OP:
# TODO: add sharding constraints on image dimension
# e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
# create sharding strategies in the same way as for a unary element-wise module
assert input_nodes_len == 1, 'Temporarily, we only support single-input element-wise ops.'
input_node = strategies_vector.predecessor_nodes[0]
# For an element-wise module, we keep the output sharding spec the same as
# the input's. Therefore, different strategies of the input node that share
# the same output sharding spec generate the same strategy for the element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# Note: the input of the node being processed is the output of its input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), 'The input node should NOT be a tuple of tensors.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# embedding module
elif submod_type in EMBEDDING_MODULE_OP:
embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
embedding_handler.register_strategy()
# layernorm module
elif submod_type in LAYERNORM_MODULE_OP:
layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
layernorm_handler.register_strategy()
# other module
else:
raise RuntimeError(f'{submod_type} module is NOT supported now.')
# call_function node
if node.op == 'call_function':
target = node.target
# conv function
if target in CONV_FUNC_OP:
# use ConvHandler to create sharding strategies for conv node
# TODO: the operator_handler does NOT support function node processing now.
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
linear_handler.register_strategy()
# where function
elif target == torch.where:
if input_nodes_len == 1:
# both x and y are scalars
pass
elif input_nodes_len == 2:
# either x or y is a scalar
pass
else:
# general case
where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
where_handler.register_strategy()
# reshape function
elif target in RESHAPE_FUNC_OP:
# use ReshapeHandler to create sharding strategies for reshape node
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
# element-wise function
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# bcast op
elif target in BCAST_FUNC_OP:
if isinstance(node._meta_data, torch.Tensor):
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
bcast_op_handler.register_strategy()
# torch.var_mean
elif target == torch.var_mean:
dim = node.kwargs['dim']
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), 'The input node should NOT be a tuple of tensors.'
entire_shape_input = input_sharding_spec.entire_shape
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
if dim in dim_partition_dict_input:
# The reduced dimension needs to be set to the replicated status
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
dim_partition_dict_for_input.pop(dim)
new_input_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_input,
dim_partition_dict=dim_partition_dict_for_input)
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[new_input_sharding_spec])
else:
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
name = f'{input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# operator.getitem
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
if isinstance(strategy.output_sharding_spec, ShardingSpec):
input_sharding_spec = strategy.output_sharding_spec
else:
input_sharding_spec = strategy.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), 'The sharding spec selected by getitem should be a single ShardingSpec.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec],
index=index)
# To prevent resharding from happening, set the corresponding resharding costs to infinity.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[strategy.output_sharding_spec])
strategies_vector.append(sharding_strategy)
# torch.arange function
elif target == torch.arange:
name = 'FULLY REPLICATED ARANGE'
entire_shape_output = node._meta_data.shape
dim_partition_dict_for_output = {}
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
memory_cost = node._meta_data.numel()
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=0,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy)
# ops temporarily passed through in order to support GPT-2
elif target in (builtins.getattr, operator.le, torch.addmm):
pass
# other function
else:
raise RuntimeError(f'{target} function is NOT supported now.')
# call_method node
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
if method in (torch.Tensor.size,):
pass
elif method in ELEMENTWISE_METHOD_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
elif method in RESHAPE_METHOD_OP:
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
else:
raise RuntimeError(f'{method} method is NOT supported now.')
# output node
if node.op == 'output':
if self.solver_options.fast:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes
input_sharding_specs = []
for input_node in input_nodes:
dim_partition_dict_for_input = {}
entire_shape = input_node._meta_data.shape
sharding_spec = ShardingSpec(self.device_mesh,
entire_shape,
dim_partition_dict=dim_partition_dict_for_input)
input_sharding_specs.append(sharding_spec)
dim_partition_dict = {}
output_sharding_spec = input_sharding_specs
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
# clear the resharding cost for the output node
# TODO: we may remove this in final version
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=tuple(input_sharding_specs))
strategies_vector.append(sharding_strategy_attribute)
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
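For orientation, here is a hedged end-to-end sketch of driving this constructor; the import paths, tracer usage, and DeviceMesh signature are assumptions based on the surrounding diff and may differ across ColossalAI versions:

import torch
import torch.nn as nn

from colossalai.fx import ColoTracer                      # assumed import path
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.options import SolverOptions

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
tracer = ColoTracer()
# meta_args let the tracer attach _meta_data to each node, which
# build_strategies_and_cost relies on when counting tensor inputs.
graph = tracer.trace(root=model, meta_args={'input': torch.rand(4, 16).to('meta')})

physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))    # assumed signature

constructor = StrategiesConstructor(graph, device_mesh, SolverOptions())
constructor.build_strategies_and_cost()
# Each node now carries a strategies_vector; leaf_strategies feeds the solver.
for strategies_vector in constructor.leaf_strategies:
    print(strategies_vector.node.name, len(strategies_vector))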

View File

@ -8,14 +8,9 @@ from torch.fx.graph import Graph
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
pass
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
shard_option: str):
'''
This method builds the strategies_constructor for the given graph.
After it runs, each node in the graph has a strategies_vector
constructed by the related node handler.
'''
solver_options = SolverOptions()
if solver_preference == 'standard':
solver_preference = SolverPerference.STANDARD
elif solver_preference == 'tp':
solver_preference = SolverPerference.TP
elif solver_preference == 'dp':
solver_preference = SolverPerference.DP
else:
raise ValueError(f'Invalid solver_preference: {solver_preference}')
if dataloader_option == 'replicated':
dataloader_option = DataloaderOption.REPLICATED
elif dataloader_option == 'distributed':
dataloader_option = DataloaderOption.DISTRIBUTED
else:
raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
if shard_option == 'standard':
shard_option = ShardOption.STANDARD
elif shard_option == 'shard':
shard_option = ShardOption.SHARD
elif shard_option == 'shard_last_axis':
shard_option = ShardOption.SHARD_LAST_AXIS
elif shard_option == 'full_shard':
shard_option = ShardOption.FULL_SHARD
else:
raise ValueError(f'Invalid shard_option: {shard_option}')
solver_options = SolverOptions(solver_perference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
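A short, hedged usage sketch of the extended entry point; the option strings come straight from the validation above:

# 'tp' gives tensor parallelism higher priority, 'distributed' shards the
# dataloader across the mesh, and 'shard' demands at least one sharded mesh axis.
strategies_constructor = build_strategy_constructor(graph,
                                                    device_mesh,
                                                    solver_preference='tp',
                                                    dataloader_option='distributed',
                                                    shard_option='shard')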
@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
the memory budget will be infinity.
overlap(optional): the overlap is used to specify whether to overlap gradient communication and
backward computing.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
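To make the new knobs concrete, a hedged call sketch; only parameters visible in this diff are used, and the meta args and solution path are placeholders:

rst_to_unpack = initialize_model(model,
                                 meta_args={'x': torch.rand(4, 16).to('meta')},   # placeholder
                                 device_mesh=device_mesh,
                                 memory_budget=-1.0,               # -1 means no memory limit
                                 overlap=False,
                                 solver_preference='dp',
                                 dataloader_option='replicated',
                                 shard_option='standard',
                                 save_solver_solution=True,
                                 solution_path='./solver_solution.pt')            # hypothetical path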
@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
graph = tracer.trace(root=model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh)
strategies_constructor = build_strategy_constructor(graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
if load_solver_solution:
solution = torch.load(solution_path)
else:
@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solution_path=solver_solution_path,
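The higher-level wrapper accepts the same strings; a hedged sketch using only the parameters visible in this diff (any required tracing or data arguments are elided):

rst_to_unpack = autoparallelize(model,
                                logical_mesh_shape=(2, 2),         # let the profiler build the mesh
                                solver_preference='tp',
                                dataloader_option='distributed',
                                shard_option='shard_last_axis',
                                load_solver_solution=True,
                                solver_solution_path='./solver_solution.pt')      # hypothetical path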

View File

@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
@ -31,6 +30,6 @@ __all__ = [
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
'TransposeHandler', 'SplitHandler'
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
'SplitHandler'
]

View File

@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
LinearProjectionStrategyGenerator(op_data_mapping,
self.device_mesh,
linear_projection_type='linear',
solver_perference=self.solver_perference))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:

View File

@ -5,7 +5,7 @@ import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@ -32,19 +32,19 @@ class NodeHandler(ABC):
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
'''
def __init__(
self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
) -> None:
def __init__(self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.shard_option = shard_option
self.solver_perference = solver_perference
def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
"""
@ -187,15 +187,24 @@ class NodeHandler(ABC):
remove_strategy_list = []
for strategy in self.strategies_vector:
shard_level = 0
shard_axis_list = []
last_axis = len(self.device_mesh.mesh_shape) - 1
for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axis in sharding_spec.dim_partition_dict.items():
shard_level += len(shard_axis)
for dim, shard_axes in sharding_spec.dim_partition_dict.items():
for shard_axis in shard_axes:
if shard_axis not in shard_axis_list:
shard_axis_list.append(shard_axis)
shard_level = len(shard_axis_list)
using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
if self.shard_option == ShardOption.SHARD and shard_level == 0:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.SHARD_LAST_AXIS:
if shard_level != 1 or using_last_axis == False:
remove_strategy_list.append(strategy)
for strategy in remove_strategy_list:
self.strategies_vector.remove(strategy)
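A self-contained sketch of the filtering rule above; the dim_partition_dict values are illustrative, and axis -1 aliases the last mesh axis:

# dim_partition_dict maps tensor dims -> list of mesh axes sharding that dim.
dim_partition_dict = {0: [0], 2: [1]}      # hypothetical spec over a 2-axis mesh
mesh_shape = (2, 2)
last_axis = len(mesh_shape) - 1

shard_axis_list = []
for shard_axes in dim_partition_dict.values():
    for shard_axis in shard_axes:
        if shard_axis not in shard_axis_list:
            shard_axis_list.append(shard_axis)

shard_level = len(shard_axis_list)                           # 2 distinct mesh axes
using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list

# ShardOption.SHARD keeps this strategy (shard_level > 0), FULL_SHARD keeps it
# (shard_level > 1), but SHARD_LAST_AXIS drops it: it requires exactly one
# sharded axis and that axis must be the last one.
print(shard_level, using_last_axis)        # -> 2 True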

View File

@ -1,17 +0,0 @@
from enum import Enum
__all__ = ['ShardOption']
class ShardOption(Enum):
"""
This enum class is to define the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be sharded using at least one device mesh axis.
FULL_SHARD: We require the node to be sharded using all device mesh axes.
"""
STANDARD = 0
SHARD = 1
FULL_SHARD = 2
