mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' into fix/format
commit 93b788b95a

@@ -96,6 +96,7 @@ jobs:
      - name: Store TensorNVMe Cache
        run: |
          cd TensorNVMe
          cp -p -r ./build /github/home/tensornvme_cache/

      - name: Checkout Colossal-AI
@@ -0,0 +1,28 @@
name: Build Documentation upon Release

on:
  workflow_dispatch:
  pull_request:
    paths:
      - 'version.txt'
      - 'docs/'
    types:
      - closed

jobs:
  build-doc:
    name: Trigger Documentation Build Workflow
    if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
    runs-on: ubuntu-latest
    steps:
      - name: trigger workflow in ColossalAI-Documentation
        run: |
          curl \
            -X POST \
            -H "Accept: application/vnd.github+json" \
            -H "Authorization: Bearer ${GH_TOKEN}" \
            -H "X-GitHub-Api-Version: 2022-11-28" \
            https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \
            -d '{"ref":"main"}'
        env:
          GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}}
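The `curl` call above is a plain `workflow_dispatch` trigger against the GitHub REST API. For reference, a minimal Python equivalent (a sketch assuming the third-party `requests` package and a `GH_TOKEN` environment variable with access to the target repo):

```python
import os

import requests  # third-party HTTP client, assumed installed

resp = requests.post(
    'https://api.github.com/repos/hpcaitech/ColossalAI-Documentation'
    '/actions/workflows/deploy.yml/dispatches',
    headers={
        'Accept': 'application/vnd.github+json',
        'Authorization': f'Bearer {os.environ["GH_TOKEN"]}',
        'X-GitHub-Api-Version': '2022-11-28',
    },
    json={'ref': 'main'},      # branch on which to run the deploy workflow
)
resp.raise_for_status()        # the dispatch endpoint returns 204 No Content on success
```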
@@ -0,0 +1,42 @@
name: Run ChatGPT examples

on:
  pull_request:
    types: [synchronize, opened, reopened]
    paths:
      - 'applications/ChatGPT/chatgpt/**'
      - 'applications/ChatGPT/requirements.txt'
      - 'applications/ChatGPT/setup.py'
      - 'applications/ChatGPT/examples/**'

jobs:
  tests:
    name: Run ChatGPT examples
    runs-on: [self-hosted, gpu]
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt
    timeout-minutes: 30
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout ColossalAI
        uses: actions/checkout@v2

      - name: Install ColossalAI and ChatGPT
        run: |
          pip install -v .
          cd applications/ChatGPT
          pip install -v .
          pip install -r examples/requirements.txt

      - name: Execute Examples
        run: |
          cd applications/ChatGPT
          ./examples/test_ci.sh
        env:
          NCCL_SHM_DISABLE: 1
          MAX_JOBS: 8
          PROMPT_PATH: /data/scratch/chatgpt/prompts.csv
@@ -0,0 +1,42 @@
name: Run ChatGPT unit tests

on:
  pull_request:
    types: [synchronize, opened, reopened]
    paths:
      - 'applications/ChatGPT/chatgpt/**'
      - 'applications/ChatGPT/requirements.txt'
      - 'applications/ChatGPT/setup.py'
      - 'applications/ChatGPT/requirements-test.txt'
      - 'applications/ChatGPT/tests/**'
      - 'applications/ChatGPT/pytest.ini'

jobs:
  tests:
    name: Run ChatGPT unit tests
    runs-on: [self-hosted, gpu]
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt
    timeout-minutes: 30
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout ColossalAI
        uses: actions/checkout@v2

      - name: Install ColossalAI and ChatGPT
        run: |
          pip install -v .
          cd applications/ChatGPT
          pip install -v .
          pip install -r requirements-test.txt

      - name: Execute Unit Testing
        run: |
          cd applications/ChatGPT
          pytest tests/
        env:
          NCCL_SHM_DISABLE: 1
          MAX_JOBS: 8
@@ -292,7 +292,13 @@ def generate_user_engagement_leaderboard_image(github_token: str, output_path: str)
    y = []

    if len(total_engagement_count) > 0:
        ranking = []
        for name, count in total_engagement_count.items():
            ranking.append((name, count))

        ranking.sort(key=lambda x: x[1], reverse=True)

        for name, count in ranking:
            x.append(count)
            y.append(name)
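As a side note on the snippet above, the manual list build plus in-place sort can also be expressed with `sorted()` in one step; a minimal equivalent sketch (the counts dictionary is hypothetical, not the repository's actual data):

```python
total_engagement_count = {'alice': 12, 'bob': 30, 'carol': 7}  # hypothetical data

# sorted() over dict items replaces the manual list build + in-place sort
ranking = sorted(total_engagement_count.items(), key=lambda item: item[1], reverse=True)

x = [count for _, count in ranking]   # engagement counts, descending
y = [name for name, _ in ranking]     # user names, aligned with x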
@@ -3,7 +3,7 @@

[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/)

Colossal-AI: A unified deep learning system for the big model era
Colossal-AI: Making large AI models cheaper, easier to use, and more scalable

<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
<a href="https://www.colossalai.org/"> Documentation </a> |
@@ -23,10 +23,10 @@
</div>

## News
* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
@@ -64,6 +64,7 @@
  <li>
    <a href="#Colossal-AI-in-the-Real-World">Colossal-AI Success Stories</a>
    <ul>
      <li><a href="#ChatGPT">ChatGPT: Low-cost reproduction of the complete ChatGPT pipeline</a></li>
      <li><a href="#AIGC">AIGC: Accelerating Stable Diffusion</a></li>
      <li><a href="#生物医药">Biomedicine: Accelerating AlphaFold protein structure prediction</a></li>
    </ul>
@@ -102,7 +103,7 @@ Colossal-AI provides you with a collection of parallel components. Our goal is to let your
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) tensor parallelism
- [Sequence parallelism](https://arxiv.org/abs/2105.13120)
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
- [Auto-parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
- [Auto-parallelism](https://arxiv.org/abs/2302.02599)
- Heterogeneous memory management
  - [PatrickStar](https://arxiv.org/abs/2108.05818)
- Friendly usage
@@ -209,6 +210,29 @@ Colossal-AI provides you with a collection of parallel components. Our goal is to let your
<p align="right">(<a href="#top">back to top</a>)</p>

## Colossal-AI Success Stories
### ChatGPT
Low-cost reproduction of the complete [ChatGPT](https://openai.com/blog/chatgpt/) pipeline [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatGPT) [[blog]](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
<p id="ChatGPT_scaling" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
</p>

- Up to 7.73x faster single-server training and 1.42x faster single-GPU inference

<p id="ChatGPT-1GPU" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
</p>

- Up to 10.3x growth in model capacity on a single GPU
- A mini demo training run requires as little as 1.62GB of GPU memory (any consumer-grade GPU)

<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
</p>

- Up to 3.7x larger fine-tuning model capacity on a single GPU
- while keeping a sufficiently high running speed

<p align="right">(<a href="#top">back to top</a>)</p>

### AIGC
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion)
README.md
@@ -3,7 +3,7 @@

[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/)

Colossal-AI: A Unified Deep Learning System for Big Model Era
Colossal-AI: Making big AI models cheaper, easier, and scalable

<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
<a href="https://www.colossalai.org/"> Documentation </a> |
@@ -24,10 +24,10 @@
</div>

## Latest News
* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)

## Table of Contents
@@ -64,6 +64,7 @@
  <li>
    <a href="#Colossal-AI-in-the-Real-World">Colossal-AI for Real World Applications</a>
    <ul>
      <li><a href="#ChatGPT">ChatGPT: Low-cost ChatGPT Equivalent Implementation Process</a></li>
      <li><a href="#AIGC">AIGC: Acceleration of Stable Diffusion</a></li>
      <li><a href="#Biomedicine">Biomedicine: Acceleration of AlphaFold Protein Structure</a></li>
    </ul>
@@ -104,7 +105,7 @@ distributed training and inference in a few lines.
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism
- [Sequence Parallelism](https://arxiv.org/abs/2105.13120)
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
- [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
- [Auto-Parallelism](https://arxiv.org/abs/2302.02599)

- Heterogeneous Memory Management
  - [PatrickStar](https://arxiv.org/abs/2108.05818)
@@ -211,6 +212,30 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
<p align="right">(<a href="#top">back to top</a>)</p>

## Colossal-AI in the Real World
### ChatGPT
A low-cost [ChatGPT](https://openai.com/blog/chatgpt/) equivalent implementation process. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatGPT) [[blog]](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
<p id="ChatGPT_scaling" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
</p>

- Up to 7.73 times faster for single-server training and 1.42 times faster for single-GPU inference

<p id="ChatGPT-1GPU" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
</p>

- Up to 10.3x growth in model capacity on one GPU
- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)

<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
</p>

- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
- while keeping a sufficiently high running speed

<p align="right">(<a href="#top">back to top</a>)</p>

### AIGC
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion).
@@ -0,0 +1,146 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/.build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# IDE
.idea/
.vscode/

# macos
*.DS_Store
#data/

docs/.build

# pytorch checkpoint
*.pt

# ignore version.py generated by setup.py
colossalai/version.py
@@ -0,0 +1,202 @@
Copyright 2021- HPC-AI Technology Inc. All rights reserved.

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021- HPC-AI Technology Inc.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
@@ -0,0 +1,114 @@
# RLHF - Colossal-AI

Implementation of RLHF (Reinforcement Learning with Human Feedback) powered by Colossal-AI. It supports distributed training and offloading, which can fit extremely large models. More details can be found in the [blog](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt).

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/chatgpt.png" width=700/>
</p>

## Training process (step 3)
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/experience.jpg" width=500/>
</p>
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/train.jpg" width=500/>
</p>


## Install
```shell
pip install .
```


## Usage

The main entrypoint is `Trainer`. Only the PPO trainer is supported for now. Several training strategies are available:

- NaiveStrategy: the simplest strategy; trains on a single GPU.
- DDPStrategy: uses `torch.nn.parallel.DistributedDataParallel`; trains on multiple GPUs.
- ColossalAIStrategy: uses Gemini and ZeRO of ColossalAI. It eliminates model duplication on each GPU and supports offloading. It is very useful when training large models on multiple GPUs.

Simplest usage:

```python
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy()

with strategy.model_init_context():
    # init your model here
    actor = Actor()
    critic = Critic()

trainer = PPOTrainer(strategy, actor=actor, critic=critic, ...)

trainer.fit(dataset, ...)
```

For more details, see `examples/`.

We also support training the reward model with real-world data. See `examples/train_reward_model.py`.

## Todo

- [x] implement PPO training
- [x] implement reward model training
- [x] support LoRA
- [ ] implement PPO-ptx fine-tuning
- [ ] integrate with Ray
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL)

## Invitation to open-source contribution
Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing power, datasets, and models are welcome to join and build an ecosystem with Colossal-AI, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!

You may contact us or participate in the following ways:
1. Post an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) or submit a [PR](https://github.com/hpcaitech/ColossalAI/pulls) on GitHub
2. Join the Colossal-AI community on
   [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
   and [WeChat](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas
3. Check out and fill in the [cooperation proposal](https://www.hpc-ai.tech/partners)
4. Send your proposal to the email contact@hpcaitech.com

Thanks so much to all of our amazing contributors!

## Quick Preview
<p id="ChatGPT_scaling" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
</p>

- Up to 7.73 times faster for single-server training and 1.42 times faster for single-GPU inference

<p id="ChatGPT-1GPU" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
</p>

- Up to 10.3x growth in model capacity on one GPU
- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)

<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
</p>

- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
- while keeping a sufficiently high running speed

## Citations

```bibtex
@article{Hu2021LoRALA,
    title   = {LoRA: Low-Rank Adaptation of Large Language Models},
    author  = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.09685}
}

@article{ouyang2022training,
    title   = {Training language models to follow instructions with human feedback},
    author  = {Ouyang, Long and Wu, Jeff and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll L and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others},
    journal = {arXiv preprint arXiv:2203.02155},
    year    = {2022}
}
```
@@ -0,0 +1,94 @@
# Benchmarks

## Benchmark GPT on dummy prompt data

We provide various GPT models (the string in parentheses is the corresponding model name used in this script):

- GPT2-S (s)
- GPT2-M (m)
- GPT2-L (l)
- GPT2-XL (xl)
- GPT2-4B (4b)
- GPT2-6B (6b)
- GPT2-8B (8b)
- GPT2-10B (10b)
- GPT2-12B (12b)
- GPT2-15B (15b)
- GPT2-18B (18b)
- GPT2-20B (20b)
- GPT2-24B (24b)
- GPT2-28B (28b)
- GPT2-32B (32b)
- GPT2-36B (36b)
- GPT2-40B (40b)
- GPT3 (175b)

We also provide various training strategies:

- ddp: torch DDP
- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3
- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload
- colossalai_zero2: ColossalAI zero2
- colossalai_zero2_cpu: ColossalAI zero2-offload
- colossalai_zero1: ColossalAI zero1
- colossalai_zero1_cpu: ColossalAI zero1-offload

Only `torchrun` launching is supported for now, e.g.

```shell
# run GPT2-S on single-node single-GPU with min batch size
torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1
# run GPT2-XL on single-node 4-GPU
torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2
# run GPT3 on 8-node 8-GPU
torchrun --nnodes 8 --nproc_per_node 8 \
    --rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \
    benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini
```

> ⚠ Batch sizes in CLI args and the output throughput/TFLOPS are all per-GPU values.

In this benchmark, we assume the model architectures/sizes of the actor and critic are the same for simplicity. In practice, however, a smaller critic may be used to reduce training cost.

We also provide a simple shell script to run a set of benchmarks. It only supports single-node benchmarks, but it is easy to run on multiple nodes by modifying the launch command in the script.

Usage:

```shell
# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh
# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2
# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2 ddp
# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2 ddp l
```

## Benchmark OPT with LoRA on dummy prompt data

We provide various OPT models (the string in parentheses is the corresponding model name used in this script):

- OPT-125M (125m)
- OPT-350M (350m)
- OPT-700M (700m)
- OPT-1.3B (1.3b)
- OPT-2.7B (2.7b)
- OPT-3.5B (3.5b)
- OPT-5.5B (5.5b)
- OPT-6.7B (6.7b)
- OPT-10B (10b)
- OPT-13B (13b)

Only `torchrun` launching is supported for now, e.g.

```shell
# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size
torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
# run OPT-350M with lora_rank=4 on single-node 4-GPU
torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4
```

> ⚠ Batch sizes in CLI args and the output throughput/TFLOPS are all per-GPU values.

In this benchmark, we assume the model architectures/sizes of the actor and critic are the same for simplicity. In practice, however, a smaller critic may be used to reduce training cost.
@@ -0,0 +1,180 @@
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import GPTActor, GPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    numel = sum(p.numel() for p in model.parameters())
    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
        # with sharded init, each rank only holds 1/world_size of the parameters
        numel *= dist.get_world_size()
    return numel


def preprocess_batch(samples) -> dict:
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def print_rank_0(*args, **kwargs) -> None:
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def print_model_numel(model_dict: dict) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = ''
    for name, numel in model_dict.items():
        outputs += f'{name}: '
        if numel >= B:
            outputs += f'{numel / B:.2f} B\n'
        elif numel >= M:
            outputs += f'{numel / M:.2f} M\n'
        elif numel >= K:
            outputs += f'{numel / K:.2f} K\n'
        else:
            outputs += f'{numel}\n'
    print_rank_0(outputs)


def get_gpt_config(model_name: str) -> GPT2Config:
    model_map = {
        's': GPT2Config(),
        'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
        'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
        'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
        '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
        '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
        '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
        '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
        '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
        '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
        '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
        '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
        '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
        '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
        '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
        '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
        '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
        '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
        '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
    }
    try:
        return model_map[model_name]
    except KeyError:
        raise ValueError(f'Unknown model "{model_name}"')


def main(args):
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_gemini_cpu':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    elif args.strategy == 'colossalai_zero1':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero1_cpu':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    model_config = get_gpt_config(args.model)

    with strategy.model_init_context():
        actor = GPTActor(config=model_config).cuda()
        critic = GPTCritic(config=model_config).cuda()

        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    actor_numel = get_model_numel(actor, strategy)
    critic_numel = get_model_numel(critic, strategy)
    initial_model_numel = get_model_numel(initial_model, strategy)
    reward_model_numel = get_model_numel(reward_model, strategy)
    print_model_numel({
        'Actor': actor_numel,
        'Critic': critic_numel,
        'Initial model': initial_model_numel,
        'Reward model': reward_model_numel
    })
    performance_evaluator = PerformanceEvaluator(actor_numel,
                                                 critic_numel,
                                                 initial_model_numel,
                                                 reward_model_numel,
                                                 enable_grad_checkpoint=False,
                                                 ignore_episodes=1)

    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=512,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='s')
    parser.add_argument('--strategy',
                        choices=[
                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
                        ],
                        default='ddp')
    parser.add_argument('--num_episodes', type=int, default=3)
    parser.add_argument('--max_timesteps', type=int, default=8)
    parser.add_argument('--update_timesteps', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    args = parser.parse_args()
    main(args)
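As a sanity check on the size labels in `get_gpt_config` above, a rough rule of thumb for decoder-only transformers is 12 · n_layer · n_embd² parameters (ignoring embeddings and biases). A small sketch, not part of the benchmark script:

```python
def approx_decoder_params(n_layer: int, n_embd: int) -> int:
    # ~12 * n_embd^2 weights per transformer block (attention + 4x-wide MLP), biases ignored
    return 12 * n_layer * n_embd ** 2

print(f"{approx_decoder_params(48, 1600) / 1e9:.2f}B")    # ~1.47B -> the 'xl' (GPT2-XL-scale) config
print(f"{approx_decoder_params(96, 12288) / 1e9:.0f}B")   # ~174B -> the '175b' (GPT-3-scale) config
```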
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env bash
|
||||
# Usage: $0 <?number-of-gpus> <?strategy> <?model>
|
||||
set -xu
|
||||
|
||||
BASE=$(realpath $(dirname $0))
|
||||
|
||||
|
||||
PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
function tune_batch_size() {
|
||||
# we found when experience batch size is equal to train batch size
|
||||
# peak CUDA memory usage of making experience phase is less than or equal to that of training phase
|
||||
# thus, experience batch size can be larger than or equal to train batch size
|
||||
for bs in 1 2 4 8 16 32 64 128 256; do
|
||||
torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1
|
||||
done
|
||||
}
|
||||
|
||||
if [ $# -eq 0 ]; then
|
||||
num_gpus=(1 2 4 8)
|
||||
else
|
||||
num_gpus=($1)
|
||||
fi
|
||||
|
||||
if [ $# -le 1 ]; then
|
||||
strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu")
|
||||
else
|
||||
strategies=($2)
|
||||
fi
|
||||
|
||||
if [ $# -le 2 ]; then
|
||||
models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b")
|
||||
else
|
||||
models=($3)
|
||||
fi
|
||||
|
||||
|
||||
for num_gpu in ${num_gpus[@]}; do
|
||||
for strategy in ${strategies[@]}; do
|
||||
for model in ${models[@]}; do
|
||||
tune_batch_size $num_gpu $model $strategy || break
|
||||
done
|
||||
done
|
||||
done
|
|
@@ -0,0 +1,175 @@
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam
from transformers import AutoTokenizer
from transformers.models.opt.configuration_opt import OPTConfig

from colossalai.nn.optimizer import HybridAdam


def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    numel = sum(p.numel() for p in model.parameters())
    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
        # with sharded init, each rank only holds 1/world_size of the parameters
        numel *= dist.get_world_size()
    return numel


def preprocess_batch(samples) -> dict:
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def print_rank_0(*args, **kwargs) -> None:
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def print_model_numel(model_dict: dict) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = ''
    for name, numel in model_dict.items():
        outputs += f'{name}: '
        if numel >= B:
            outputs += f'{numel / B:.2f} B\n'
        elif numel >= M:
            outputs += f'{numel / M:.2f} M\n'
        elif numel >= K:
            outputs += f'{numel / K:.2f} K\n'
        else:
            outputs += f'{numel}\n'
    print_rank_0(outputs)


def get_opt_config(model_name: str) -> OPTConfig:
    # renamed from get_gpt_config: this script builds OPT configs
    model_map = {
        '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
        '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
        '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
        '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
        '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
        '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
        '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
        '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
        '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
        '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
    }
    try:
        return model_map[model_name]
    except KeyError:
        raise ValueError(f'Unknown model "{model_name}"')


def main(args):
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_gemini_cpu':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    elif args.strategy == 'colossalai_zero1':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero1_cpu':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)

    model_config = get_opt_config(args.model)

    with strategy.model_init_context():
        actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
        critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()

        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    actor_numel = get_model_numel(actor, strategy)
    critic_numel = get_model_numel(critic, strategy)
    initial_model_numel = get_model_numel(initial_model, strategy)
    reward_model_numel = get_model_numel(reward_model, strategy)
    print_model_numel({
        'Actor': actor_numel,
        'Critic': critic_numel,
        'Initial model': initial_model_numel,
        'Reward model': reward_model_numel
    })
    performance_evaluator = PerformanceEvaluator(actor_numel,
                                                 critic_numel,
                                                 initial_model_numel,
                                                 reward_model_numel,
                                                 enable_grad_checkpoint=False,
                                                 ignore_episodes=1)

    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    tokenizer.pad_token = tokenizer.eos_token

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=512,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='125m')
    parser.add_argument('--strategy',
                        choices=[
                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
                        ],
                        default='ddp')
    parser.add_argument('--num_episodes', type=int, default=3)
    parser.add_argument('--max_timesteps', type=int, default=8)
    parser.add_argument('--update_timesteps', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=4)
    parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
    args = parser.parse_args()
    main(args)
@@ -0,0 +1,3 @@
from .reward_dataset import RewardDataset

__all__ = ['RewardDataset']
@@ -0,0 +1,52 @@
from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm


class RewardDataset(Dataset):
    """
    Dataset for reward model

    Args:
        dataset: dataset for reward model
        tokenizer: tokenizer for reward model
        max_length: max length of input
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        for data in tqdm(dataset):
            prompt = data['prompt']

            chosen = prompt + data['chosen'] + "<|endoftext|>"
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + "<|endoftext|>"
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]
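A minimal usage sketch for `RewardDataset` (the package path `chatgpt.dataset` and the `Dahoas/rm-static` dataset name are assumptions; any source whose records expose `prompt`, `chosen`, and `rejected` fields works the same way):

```python
from datasets import load_dataset            # HuggingFace datasets, assumed installed
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from chatgpt.dataset import RewardDataset    # assumed package path

raw = load_dataset('Dahoas/rm-static', split='train')   # hypothetical choice of pairwise data
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token               # GPT-2 has no pad token by default

dataset = RewardDataset(raw, tokenizer, max_length=512)
# each item is four tensors of shape (1, max_length), because the tokenizer
# is called with return_tensors="pt" on a single string in __init__
chosen_ids, chosen_mask, reject_ids, reject_mask = dataset[0]
loader = DataLoader(dataset, batch_size=4, shuffle=True)
```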
@@ -0,0 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker

__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
@@ -0,0 +1,77 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from chatgpt.nn.actor import Actor


@dataclass
class Experience:
    """Experience is a batch of data.
    All tensors in a batch share the same sequence length and number of actions.
    Left padding is applied to sequences.

    Shapes of each tensor:
    sequences: (B, S)
    action_log_probs: (B, A)
    values: (B)
    reward: (B)
    advantages: (B)
    attention_mask: (B, S)
    action_mask: (B, A)

    "A" is the number of actions.
    """
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]

    @torch.no_grad()
    def to_device(self, device: torch.device) -> None:
        self.sequences = self.sequences.to(device)
        self.action_log_probs = self.action_log_probs.to(device)
        self.values = self.values.to(device)
        self.reward = self.reward.to(device)
        self.advantages = self.advantages.to(device)
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.to(device)
        if self.action_mask is not None:
            self.action_mask = self.action_mask.to(device)

    def pin_memory(self):
        self.sequences = self.sequences.pin_memory()
        self.action_log_probs = self.action_log_probs.pin_memory()
        self.values = self.values.pin_memory()
        self.reward = self.reward.pin_memory()
        self.advantages = self.advantages.pin_memory()
        if self.attention_mask is not None:
            self.attention_mask = self.attention_mask.pin_memory()
        if self.action_mask is not None:
            self.action_mask = self.action_mask.pin_memory()
        return self


class ExperienceMaker(ABC):

    def __init__(self,
                 actor: Actor,
                 critic: nn.Module,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 kl_coef: float = 0.1) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.initial_model = initial_model
        self.kl_coef = kl_coef

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
        pass
@ -0,0 +1,36 @@
import torch
from chatgpt.nn.utils import compute_reward, normalize

from .base import Experience, ExperienceMaker


class NaiveExperienceMaker(ExperienceMaker):
    """
    Naive experience maker.
    """

    @torch.no_grad()
    def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
        self.actor.eval()
        self.critic.eval()
        self.initial_model.eval()
        self.reward_model.eval()

        sequences, attention_mask, action_mask = self.actor.generate(input_ids,
                                                                     return_action_mask=True,
                                                                     **generate_kwargs)
        num_actions = action_mask.size(1)

        action_log_probs = self.actor(sequences, num_actions, attention_mask)
        base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
        value = self.critic(sequences, action_mask, attention_mask)
        r = self.reward_model(sequences, attention_mask)

        reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)

        advantage = reward - value
        # TODO(ver217): maybe normalize adv
        if advantage.ndim == 1:
            advantage = advantage.unsqueeze(-1)

        return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)

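The reward/advantage arithmetic above is easy to check by hand; a toy sketch with invented values:

# Toy check of the KL-penalized reward and one-step advantage (values invented).
import torch

r = torch.tensor([1.0, 0.5])     # reward model score, shape (B)
kl = torch.tensor([0.2, 0.4])    # approximate KL to the initial policy, shape (B)
value = torch.tensor([0.6, 0.1])    # critic value estimate, shape (B)
kl_coef = 0.1

reward = r - kl_coef * kl        # KL-penalized reward: (B)
advantage = reward - value       # one-step advantage estimate: (B)
print(reward, advantage)         # tensor([0.9800, 0.4600]) tensor([0.3800, 0.3600])
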
@ -0,0 +1,18 @@
from .actor import Actor
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
from .critic import Critic
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
from .reward_model import RewardModel

__all__ = [
    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss', 'GPTActor',
    'GPTCritic', 'GPTRM', 'BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'OPTActor', 'OPTCritic', 'OPTRM'
]

@ -0,0 +1,62 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .generation import generate
from .lora import LoRAModule
from .utils import log_probs_from_logits


class Actor(LoRAModule):
    """
    Actor model base class.

    Args:
        model (nn.Module): Actor Model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        return_action_mask: bool = True,
        **kwargs
    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
        sequences = generate(self.model, input_ids, **kwargs)
        attention_mask = None
        pad_token_id = kwargs.get('pad_token_id', None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask
        input_len = input_ids.size(1)
        eos_token_id = kwargs.get('eos_token_id', None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

    def forward(self,
                sequences: torch.LongTensor,
                num_actions: int,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns action log probs
        """
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]

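The cumsum trick in `generate` marks everything after the first EOS among generated tokens as a non-action; a toy illustration with invented ids:

# Toy illustration (invented ids, eos_token_id = 0, prompt length 2).
import torch
import torch.nn.functional as F

input_len = 2
sequences = torch.tensor([[5, 6, 7, 0, 9]])    # 2 prompt tokens, then 7, EOS, 9
action_mask = (sequences[:, input_len:] == 0).cumsum(dim=-1) == 0
print(action_mask)    # tensor([[ True, False, False]]) -- True only before the first EOS
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)
print(action_mask)    # tensor([[ True,  True,  True,  True, False]]) -- the EOS itself stays in
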
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from .actor import Actor
|
||||
|
||||
|
||||
class BLOOMActor(Actor):
|
||||
"""
|
||||
BLOOM Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = BloomForCausalLM(config)
|
||||
else:
|
||||
model = BloomForCausalLM(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,37 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from .critic import Critic
|
||||
|
||||
|
||||
class BLOOMCritic(Critic):
|
||||
"""
|
||||
BLOOM Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,37 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from .reward_model import RewardModel
|
||||
|
||||
|
||||
class BLOOMRM(RewardModel):
|
||||
"""
|
||||
BLOOM Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,47 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .lora import LoRAModule
|
||||
from .utils import masked_mean
|
||||
|
||||
|
||||
class Critic(LoRAModule):
|
||||
"""
|
||||
Critic model base class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Critic model.
|
||||
value_head (nn.Module): Value head to get value.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: nn.Module,
|
||||
value_head: nn.Module,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.value_head = value_head
|
||||
self.convert_to_lora()
|
||||
|
||||
def forward(self,
|
||||
sequences: torch.LongTensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||
last_hidden_states = outputs['last_hidden_state']
|
||||
|
||||
values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
|
||||
|
||||
if action_mask is not None:
|
||||
num_actions = action_mask.size(1)
|
||||
values = values[:, -num_actions:]
|
||||
value = masked_mean(values, action_mask, dim=1)
|
||||
return value
|
||||
value = values.mean(dim=1).squeeze(1)
|
||||
return value
|
|
@ -0,0 +1,137 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
except ImportError:
|
||||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
|
||||
def prepare_logits_processor(top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
||||
processor_list = LogitsProcessorList()
|
||||
if temperature is not None and temperature != 1.0:
|
||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
processor_list.append(TopKLogitsWarper(top_k))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
processor_list.append(TopPLogitsWarper(top_p))
|
||||
return processor_list
|
||||
|
||||
|
||||
def sample(model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
if input_ids.size(1) >= max_length:
|
||||
return input_ids
|
||||
|
||||
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
for _ in range(input_ids.size(1), max_length):
|
||||
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
|
||||
'input_ids': input_ids
|
||||
}
|
||||
outputs = model(**model_inputs)
|
||||
|
||||
next_token_logits = outputs['logits'][:, -1, :]
|
||||
# pre-process distribution
|
||||
next_token_logits = logits_processor(input_ids, next_token_logits)
|
||||
# sample
|
||||
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
|
||||
# update generated ids, model inputs for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
if update_model_kwargs_fn is not None:
|
||||
model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
||||
|
||||
# stop when each sentence is finished if early_stopping=True
|
||||
if early_stopping and unfinished_sequences.max() == 0:
|
||||
break
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
def generate(model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
num_beams: int = 1,
|
||||
do_sample: bool = True,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model
|
||||
input_ids (torch.Tensor): input sequence
|
||||
max_length (int): max length of the returned sequence
|
||||
num_beams (int, optional): number of beams. Defaults to 1.
|
||||
do_sample (bool, optional): whether to do sample. Defaults to True.
|
||||
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
|
||||
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
|
||||
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
|
||||
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
|
||||
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
|
||||
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
|
||||
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
|
||||
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
|
||||
"""
|
||||
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
|
||||
is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
|
||||
is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
|
||||
if is_greedy_gen_mode:
|
||||
# run greedy search
|
||||
raise NotImplementedError
|
||||
elif is_sample_gen_mode:
|
||||
# run sample
|
||||
return sample(model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs)
|
||||
elif is_beam_gen_mode:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError("Unsupported generation mode")
|
|
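A minimal usage sketch for the sampling path (assumes a Hugging Face causal LM; the greedy and beam modes above are not implemented yet):

# Usage sketch: sampling from GPT-2 with the loop above (no KV-cache helpers).
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from chatgpt.nn.generation import generate

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

input_ids = tokenizer("The weather today is", return_tensors='pt')['input_ids']
sequences = generate(model,
                     input_ids,
                     max_length=20,
                     do_sample=True,    # only the sampling path is implemented
                     early_stopping=True,
                     eos_token_id=tokenizer.eos_token_id,
                     pad_token_id=tokenizer.eos_token_id,
                     top_k=50,
                     temperature=0.7)
print(tokenizer.decode(sequences[0]))
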
@ -0,0 +1,92 @@
from typing import Optional

import torch


def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
    token_type_ids = kwargs.get("token_type_ids", None)
    # only last token for input_ids if past is defined in kwargs
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = position_ids[:, -1].unsqueeze(-1)
    else:
        position_ids = None
    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }


def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

    return model_kwargs


def opt_prepare_inputs_fn(input_ids: torch.Tensor,
                          past: Optional[torch.Tensor] = None,
                          attention_mask: Optional[torch.Tensor] = None,
                          use_cache: Optional[bool] = None,
                          **kwargs) -> dict:
    # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    if past:
        input_ids = input_ids[:, -1:]
    # first step, decoder_cached_states are empty
    return {
        "input_ids": input_ids,    # encoder_outputs is defined. input_ids not needed
        "attention_mask": attention_mask,
        "past_key_values": past,
        "use_cache": use_cache,
    }


def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
                            past: Optional[torch.Tensor] = None,
                            attention_mask: Optional[torch.Tensor] = None,
                            use_cache: Optional[bool] = None,
                            **kwargs) -> dict:
    # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    if past:
        input_ids = input_ids[:, -1:]
    # first step, decoder_cached_states are empty
    return {
        "input_ids": input_ids,    # encoder_outputs is defined. input_ids not needed
        "attention_mask": attention_mask,
        "past_key_values": past,
        "use_cache": use_cache,
    }

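A sketch of wiring these helpers into `generate` so the KV cache is reused between steps (a GPT-2 model is assumed; `use_cache` is threaded through `model_kwargs`):

# Sketch: cached sampling via the prepare/update helpers above.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from chatgpt.nn.generation import generate
from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
input_ids = tokenizer("Hello", return_tensors='pt')['input_ids']

sequences = generate(model,
                     input_ids,
                     max_length=16,
                     eos_token_id=tokenizer.eos_token_id,
                     pad_token_id=tokenizer.eos_token_id,
                     prepare_inputs_fn=gpt_prepare_inputs_fn,
                     update_model_kwargs_fn=update_model_kwargs_fn,
                     use_cache=True)
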
@ -0,0 +1,31 @@
from typing import Optional

from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from .actor import Actor


class GPTActor(Actor):
    """
    GPT Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False) -> None:
        if pretrained is not None:
            model = GPT2LMHeadModel.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2LMHeadModel(config)
        else:
            model = GPT2LMHeadModel(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model)

@ -0,0 +1,33 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||
|
||||
from .critic import Critic
|
||||
|
||||
|
||||
class GPTCritic(Critic):
|
||||
"""
|
||||
GPT Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
super().__init__(model, value_head)
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||
|
||||
from .reward_model import RewardModel
|
||||
|
||||
|
||||
class GPTRM(RewardModel):
|
||||
"""
|
||||
GPT Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
super().__init__(model, value_head)
|
|
@ -0,0 +1,127 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: nn.Parameter,
|
||||
bias: Optional[nn.Parameter],
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
merge_weights: bool = True,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
lora.LoRALayer.__init__(self,
|
||||
r=r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
merge_weights=merge_weights)
|
||||
self.weight = weight
|
||||
self.bias = bias
|
||||
|
||||
out_features, in_features = weight.shape
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
|
||||
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
nn.Module.train(self, mode)
|
||||
if self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
nn.Module.eval(self)
|
||||
if self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
if self.r > 0 and not self.merged:
|
||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
||||
if self.r > 0:
|
||||
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
|
||||
return result
|
||||
else:
|
||||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
|
||||
|
||||
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
||||
return lora_linear
|
||||
|
||||
|
||||
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, nn.Linear):
|
||||
setattr(module, name, lora_linear_wrapper(child, lora_rank))
|
||||
else:
|
||||
convert_to_lora_recursively(child, lora_rank)
|
||||
|
||||
|
||||
class LoRAModule(nn.Module):
|
||||
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
|
||||
This calss will convert all torch.nn.Linear layer to LoraLinear layer.
|
||||
|
||||
Args:
|
||||
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
|
||||
lora_train_bias (str, optional): Whether LoRA train biases.
|
||||
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
|
||||
Defaults to 'none'.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
||||
super().__init__()
|
||||
self.lora_rank = lora_rank
|
||||
self.lora_train_bias = lora_train_bias
|
||||
|
||||
def convert_to_lora(self) -> None:
|
||||
if self.lora_rank <= 0:
|
||||
return
|
||||
convert_to_lora_recursively(self, self.lora_rank)
|
||||
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
|
|
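In effect LoraLinear computes y = x W^T + (lora_alpha / r) * x A^T B^T with only A and B trainable; a minimal sketch of converting a module (the tiny model below is invented for illustration):

# Sketch: converting a module to LoRA and checking what stays trainable.
import torch.nn as nn

from chatgpt.nn.lora import LoRAModule


class TinyLoRAModel(LoRAModule):

    def __init__(self, lora_rank: int = 4) -> None:
        super().__init__(lora_rank=lora_rank)
        self.proj = nn.Linear(16, 16)
        self.convert_to_lora()    # must be called at the bottom of __init__


m = TinyLoRAModel()
trainable = [n for n, p in m.named_parameters() if p.requires_grad]
print(trainable)    # only ['proj.lora_A', 'proj.lora_B'] remain trainable
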
@ -0,0 +1,105 @@
from typing import Optional

import torch
import torch.nn as nn

from .utils import masked_mean


class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss


class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward)**2
        surr2 = (values - reward)**2
        loss = torch.max(surr1, surr2)
        loss = loss.mean()
        return loss


class PPOPtxActorLoss(nn.Module):
    """
    PPO-ptx Actor Loss: the PPO policy loss combined with a pretraining
    language-model loss, as in InstructGPT. (TODO)
    """

    def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        self.pretrain_loss_fn = pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss


class PairWiseLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(chosen_reward - reject_reward)
        log_probs = torch.log(probs)
        loss = -log_probs.mean()
        return loss

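A toy numeric check of the clipped policy objective and the pairwise ranking loss (all values invented):

# Toy check of PolicyLoss and PairWiseLoss.
import torch

from chatgpt.nn import PairWiseLoss, PolicyLoss

log_probs = torch.log(torch.tensor([[0.9]]))    # new policy
old_log_probs = torch.log(torch.tensor([[0.6]]))    # behavior policy; ratio = 1.5
advantages = torch.tensor([[1.0]])

loss = PolicyLoss(clip_eps=0.2)(log_probs, old_log_probs, advantages)
print(loss)    # ratio is clipped to 1.2, so loss = -1.2

chosen_r, reject_r = torch.tensor([1.0]), torch.tensor([-1.0])
print(PairWiseLoss()(chosen_r, reject_r))    # -log(sigmoid(2)) ~= 0.1269
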
@ -0,0 +1,35 @@
from typing import Optional

from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from .actor import Actor


class OPTActor(Actor):
    """
    OPT Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = OPTForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = OPTForCausalLM(config)
        else:
            model = OPTForCausalLM(OPTConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)

@ -0,0 +1,37 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTModel
|
||||
|
||||
from .critic import Critic
|
||||
|
||||
|
||||
class OPTCritic(Critic):
|
||||
"""
|
||||
OPT Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTModel
|
||||
|
||||
from .reward_model import RewardModel
|
||||
|
||||
|
||||
class OPTRM(RewardModel):
|
||||
"""
|
||||
OPT Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,41 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .lora import LoRAModule
|
||||
|
||||
|
||||
class RewardModel(LoRAModule):
|
||||
"""
|
||||
Reward model base class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Reward model.
|
||||
value_head (nn.Module): Value head to get reward score.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: nn.Module,
|
||||
value_head: Optional[nn.Module] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
if value_head is not None:
|
||||
if value_head.out_features != 1:
|
||||
raise ValueError("The value head of reward model's output dim should be 1!")
|
||||
self.value_head = value_head
|
||||
|
||||
else:
|
||||
self.value_head = nn.Linear(model.config.n_embd, 1)
|
||||
self.convert_to_lora()
|
||||
|
||||
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||
last_hidden_states = outputs['last_hidden_state']
|
||||
values = self.value_head(last_hidden_states)[:, :-1]
|
||||
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
|
||||
return value
|
|
@ -0,0 +1,92 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def compute_approx_kl(log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Compute the approximate KL divergence between two distributions.
|
||||
Schulman blog: http://joschu.net/blog/kl-approx.html
|
||||
|
||||
Args:
|
||||
log_probs: Log probabilities of the new distribution.
|
||||
log_probs_base: Log probabilities of the base distribution.
|
||||
action_mask: Mask for actions.
|
||||
"""
|
||||
|
||||
log_ratio = log_probs - log_probs_base
|
||||
approx_kl = (log_ratio.exp() - 1) - log_ratio
|
||||
if action_mask is not None:
|
||||
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
|
||||
return approx_kl
|
||||
approx_kl = approx_kl.mean(dim=1)
|
||||
return approx_kl
|
||||
|
||||
|
||||
def compute_reward(r: Union[torch.Tensor, float],
|
||||
kl_coef: float,
|
||||
log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if kl_coef <= 0.0:
|
||||
return r
|
||||
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
||||
reward = r - kl_coef * kl
|
||||
return reward
|
||||
|
||||
|
||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return log_probs_labels.squeeze(-1)
|
||||
|
||||
|
||||
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
|
||||
tensor = tensor * mask
|
||||
tensor = tensor.sum(dim=dim)
|
||||
mask_sum = mask.sum(dim=dim)
|
||||
mean = tensor / (mask_sum + 1e-8)
|
||||
return mean
|
||||
|
||||
|
||||
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
|
||||
tensor = tensor * mask
|
||||
mean = masked_mean(tensor, mask, dim=dim)
|
||||
mean_centered = tensor - mean
|
||||
var = masked_mean(mean_centered**2, mask, dim=dim)
|
||||
return mean_centered * var.clamp(min=eps).rsqrt()
|
||||
|
||||
|
||||
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
|
||||
mean = tensor.mean(dim)
|
||||
mean_centered = tensor - mean
|
||||
var = (mean_centered**2).mean(dim)
|
||||
norm = mean_centered * var.clamp(min=eps).rsqrt()
|
||||
return norm
|
||||
|
||||
|
||||
def convert_to_lora(model: nn.Module,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
lora_rank: int = 16,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False,
|
||||
merge_weights: bool = True):
|
||||
if lora_rank > min(input_size, output_size):
|
||||
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
module._modules[name] = lora.Linear(input_size,
|
||||
output_size,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
fan_in_fan_out=fan_in_fan_out,
|
||||
merge_weights=merge_weights)
|
|
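A quick numeric check of `masked_mean` (values invented):

# masked_mean averages only over unmasked entries.
import torch

from chatgpt.nn.utils import masked_mean

x = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])    # last position is padding
print(masked_mean(x, mask))    # tensor([1.5000])
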
@ -0,0 +1,4 @@
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer

__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']

@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Any

from chatgpt.experience_maker.base import Experience


class ReplayBuffer(ABC):
    """Replay buffer base class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
        super().__init__()
        self.sample_batch_size = sample_batch_size
        # limit <= 0 means unlimited
        self.limit = limit

    @abstractmethod
    def append(self, experience: Experience) -> None:
        pass

    @abstractmethod
    def clear(self) -> None:
        pass

    @abstractmethod
    def sample(self) -> Experience:
        pass

    @abstractmethod
    def __len__(self) -> int:
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> Any:
        pass

    @abstractmethod
    def collate_fn(self, batch: Any) -> Experience:
        pass

@ -0,0 +1,57 @@
import random
from typing import List

import torch
from chatgpt.experience_maker.base import Experience

from .base import ReplayBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch


class NaiveReplayBuffer(ReplayBuffer):
    """Naive replay buffer class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
        cpu_offload (bool, optional): Whether to offload experience to CPU when appending (it is moved back to the current CUDA device when sampling). Defaults to True.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
        super().__init__(sample_batch_size, limit)
        self.cpu_offload = cpu_offload
        self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
        # TODO(ver217): add prefetch
        self.items: List[BufferItem] = []

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        if self.cpu_offload:
            experience.to_device(torch.device('cpu'))
        items = split_experience_batch(experience)
        self.items.extend(items)
        if self.limit > 0:
            samples_to_remove = len(self.items) - self.limit
            if samples_to_remove > 0:
                self.items = self.items[samples_to_remove:]

    def clear(self) -> None:
        self.items.clear()

    @torch.no_grad()
    def sample(self) -> Experience:
        items = random.sample(self.items, self.sample_batch_size)
        experience = make_experience_batch(items)
        if self.cpu_offload:
            experience.to_device(self.target_device)
        return experience

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> BufferItem:
        return self.items[idx]

    def collate_fn(self, batch) -> Experience:
        experience = make_experience_batch(batch)
        return experience

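Because the buffer implements `__len__`, `__getitem__` and `collate_fn`, it can back a torch `DataLoader` directly; a sketch (assumes a CUDA device is available, since the constructor queries it, and that experience has already been appended):

# Sketch: the buffer as a map-style dataset behind a DataLoader.
from torch.utils.data import DataLoader

from chatgpt.replay_buffer import NaiveReplayBuffer

buffer = NaiveReplayBuffer(sample_batch_size=8)
# ... buffer.append(experience) during rollout ...
dataloader = DataLoader(buffer, batch_size=8, shuffle=True, collate_fn=buffer.collate_fn, pin_memory=True)
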
@ -0,0 +1,73 @@
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn.functional as F
from chatgpt.experience_maker.base import Experience


@dataclass
class BufferItem:
    """BufferItem is an item of experience data.

    Shapes of each tensor:
    sequences: (S)
    action_log_probs: (A)
    values: (1)
    reward: (1)
    advantages: (1)
    attention_mask: (S)
    action_mask: (A)

    "A" is the number of actions.
    """
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]


def split_experience_batch(experience: Experience) -> List[BufferItem]:
    batch_size = experience.sequences.size(0)
    batch_kwargs = [{} for _ in range(batch_size)]
    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
    for key in keys:
        value = getattr(experience, key)
        if isinstance(value, torch.Tensor):
            vals = torch.unbind(value)
        else:
            # None
            vals = [value for _ in range(batch_size)]
        assert batch_size == len(vals)
        for i, v in enumerate(vals):
            batch_kwargs[i][key] = v
    items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
    return items


def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
    assert side in ('left', 'right')
    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        padding = (pad_len, 0) if side == 'left' else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)


def make_experience_batch(items: List[BufferItem]) -> Experience:
    kwargs = {}
    to_pad_keys = set(('action_log_probs', 'action_mask'))
    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if key in to_pad_keys:
            batch_data = zero_pad_sequences(vals)
        else:
            batch_data = torch.stack(vals, dim=0)
        kwargs[key] = batch_data
    return Experience(**kwargs)

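A quick check of `zero_pad_sequences` (left padding is the default, matching the `Experience` convention):

# Shorter items are left-padded with zeros before stacking.
import torch

from chatgpt.replay_buffer.utils import zero_pad_sequences

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
print(zero_pad_sequences([a, b]))
# tensor([[1, 2, 3],
#         [0, 4, 5]])
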
@ -0,0 +1,5 @@
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer

__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']

@ -0,0 +1,162 @@
import random
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from chatgpt.experience_maker import Experience, ExperienceMaker
from chatgpt.replay_buffer import ReplayBuffer
from torch import Tensor
from torch.utils.data import DistributedSampler
from tqdm import tqdm

from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0


class Trainer(ABC):
    """
    Base class for RLHF trainers.

    Args:
        strategy (Strategy): the strategy to use for training
        experience_maker (ExperienceMaker): the experience maker used to produce experience to fill the replay buffer
        replay_buffer (ReplayBuffer): the replay buffer to use for training
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of training process
        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
        generate_kwargs (dict, optional): the kwargs to use while the model is generating
    """

    def __init__(self,
                 strategy: Strategy,
                 experience_maker: ExperienceMaker,
                 replay_buffer: ReplayBuffer,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
        super().__init__()
        self.strategy = strategy
        self.experience_maker = experience_maker
        self.replay_buffer = replay_buffer
        self.experience_batch_size = experience_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        self.generate_kwargs = generate_kwargs
        self.sample_replay_buffer = sample_replay_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        pass

    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    def _sample_prompts(self, prompts) -> list:
        indices = list(range(len(prompts)))
        sampled_indices = random.sample(indices, self.experience_batch_size)
        return [prompts[i] for i in sampled_indices]

    def _learn(self):
        # the replay buffer may be empty at first, so the dataloader is rebuilt on each call
        if not self.sample_replay_buffer:
            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
            device = torch.cuda.current_device()
        if self.sample_replay_buffer:
            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
            for _ in pbar:
                experience = self.replay_buffer.sample()
                metrics = self.training_step(experience)
                pbar.set_postfix(metrics)
        else:
            for epoch in range(self.max_epochs):
                self._on_learn_epoch_start(epoch)
                if isinstance(dataloader.sampler, DistributedSampler):
                    dataloader.sampler.set_epoch(epoch)
                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
                for experience in pbar:
                    self._on_learn_batch_start()
                    experience.to_device(device)
                    metrics = self.training_step(experience)
                    self._on_learn_batch_end(metrics, experience)
                    pbar.set_postfix(metrics)
                self._on_learn_epoch_end(epoch)

    def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
        time = 0
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for timestep in tqdm(range(max_timesteps),
                                 desc=f'Episode [{episode+1}/{num_episodes}]',
                                 disable=not is_rank_0()):
                time += 1
                rand_prompts = self._sample_prompts(prompts)
                if self.tokenizer is not None:
                    inputs = self.tokenizer(rand_prompts)
                else:
                    inputs = rand_prompts
                self._on_make_experience_start()
                experience = self._make_experience(inputs)
                self._on_make_experience_end(experience)
                self.replay_buffer.append(experience)
                if time % update_timesteps == 0:
                    self._learn()
                    self.replay_buffer.clear()
            self._on_episode_end(episode)
        self._on_fit_end()

    # TODO(ver217): maybe simplify these code using context
    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)

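A sketch of a concrete trainer; only `training_step` must be implemented (the loss here is a placeholder, not the real PPO step):

# Sketch: subclassing Trainer. A real implementation computes PolicyLoss /
# ValueLoss from the experience and steps the optimizers via the strategy.
from typing import Any, Dict

from chatgpt.experience_maker import Experience
from chatgpt.trainer import Trainer


class ToyTrainer(Trainer):

    def training_step(self, experience: Experience) -> Dict[str, Any]:
        # use experience.action_log_probs / experience.advantages here
        return {'loss': 0.0}
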
@ -0,0 +1,4 @@
from .base import Callback
from .performance_evaluator import PerformanceEvaluator

__all__ = ['Callback', 'PerformanceEvaluator']

@ -0,0 +1,39 @@
from abc import ABC

from chatgpt.experience_maker import Experience


class Callback(ABC):
    """
    Base callback class. It defines the interface for callbacks.
    """

    def on_fit_start(self) -> None:
        pass

    def on_fit_end(self) -> None:
        pass

    def on_episode_start(self, episode: int) -> None:
        pass

    def on_episode_end(self, episode: int) -> None:
        pass

    def on_make_experience_start(self) -> None:
        pass

    def on_make_experience_end(self, experience: Experience) -> None:
        pass

    def on_learn_epoch_start(self, epoch: int) -> None:
        pass

    def on_learn_epoch_end(self, epoch: int) -> None:
        pass

    def on_learn_batch_start(self) -> None:
        pass

    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        pass

|
|||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from chatgpt.experience_maker import Experience
|
||||
|
||||
from .base import Callback
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
return 1
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs) -> None:
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class PerformanceEvaluator(Callback):
|
||||
"""
|
||||
Callback for valuate the performance of the model.
|
||||
Args:
|
||||
actor_num_params: The number of parameters of the actor model.
|
||||
critic_num_params: The number of parameters of the critic model.
|
||||
initial_model_num_params: The number of parameters of the initial model.
|
||||
reward_model_num_params: The number of parameters of the reward model.
|
||||
enable_grad_checkpoint: Whether to enable gradient checkpointing.
|
||||
ignore_episodes: The number of episodes to ignore when calculating the performance.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
actor_num_params: int,
|
||||
critic_num_params: int,
|
||||
initial_model_num_params: int,
|
||||
reward_model_num_params: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_episodes: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
self.critic_num_params = critic_num_params
|
||||
self.initial_model_num_params = initial_model_num_params
|
||||
self.reward_model_num_params = reward_model_num_params
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_episodes = ignore_episodes
|
||||
self.disable: bool = False
|
||||
|
||||
self.make_experience_duration: float = 0.
|
||||
self.make_experience_start_time: Optional[float] = None
|
||||
self.make_experience_num_samples: int = 0
|
||||
self.make_experience_flop: int = 0
|
||||
self.learn_duration: float = 0.
|
||||
self.learn_start_time: Optional[float] = None
|
||||
self.learn_num_samples: int = 0
|
||||
self.learn_flop: int = 0
|
||||
|
||||
def on_episode_start(self, episode: int) -> None:
|
||||
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
self.make_experience_start_time = time()
|
||||
|
||||
def on_make_experience_end(self, experience: Experience) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
self.make_experience_duration += time() - self.make_experience_start_time
|
||||
|
||||
batch_size, seq_len = experience.sequences.shape
|
||||
|
||||
self.make_experience_num_samples += batch_size
|
||||
|
||||
# actor generate
|
||||
num_actions = experience.action_mask.size(1)
|
||||
input_len = seq_len - num_actions
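        # each of the num_actions generation steps attends over a growing context,
        # from input_len up to seq_len - 1 tokens; the arithmetic-series sum below
        # therefore counts num_actions * (input_len + seq_len - 1) / 2 tokens in total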
        total_seq_len = (input_len + seq_len - 1) * num_actions / 2
        self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
        # actor forward
        self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
        # critic forward
        self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
        # initial model forward
        self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
        # reward model forward
        self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2

    def on_learn_batch_start(self) -> None:
        if self.disable:
            return
        self.learn_start_time = time()

    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        if self.disable:
            return
        self.learn_duration += time() - self.learn_start_time

        batch_size, seq_len = experience.sequences.shape

        self.learn_num_samples += batch_size

        # actor forward-backward, 3 means forward(1) + backward(2);
        # gradient checkpointing adds one extra forward for recomputation
        self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
        # critic forward-backward
        self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))

    def on_fit_end(self) -> None:
        avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
        avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)

        avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
        avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)

        avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
        avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)

        print_rank_0(
            f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
        )
        print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')
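A sketch of wiring the evaluator into a trainer (counting parameters via `numel()` is a common idiom and an assumption here; the repository may count parameters elsewhere):

```python
def num_params(model) -> int:
    # total number of elements across all parameters
    return sum(p.numel() for p in model.parameters())

evaluator = PerformanceEvaluator(actor_num_params=num_params(actor),
                                 critic_num_params=num_params(critic),
                                 initial_model_num_params=num_params(initial_model),
                                 reward_model_num_params=num_params(reward_model),
                                 ignore_episodes=1)
# pass it to the trainer, e.g. PPOTrainer(..., callbacks=[evaluator])
```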
@ -0,0 +1,114 @@
from typing import Any, Callable, Dict, List, Optional

import torch.nn as nn
from chatgpt.experience_maker import Experience, NaiveExperienceMaker
from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
from chatgpt.nn.generation_utils import update_model_kwargs_fn
from chatgpt.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer

from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy


class PPOTrainer(Trainer):
    """
    Trainer for the PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in the PPO algorithm
        critic (Critic): the critic model in the PPO algorithm
        reward_model (nn.Module): the reward model in the RLHF algorithm, which computes the reward of sentences
        initial_model (Actor): the initial model in the RLHF algorithm, which generates reference logits to limit the update of the actor
        actor_optim (Optimizer): the optimizer to use for the actor model
        critic_optim (Optimizer): the optimizer to use for the critic model
        kl_coef (float, defaults to 0.1): the coefficient of the KL divergence loss
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of the replay buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload the replay buffer to CPU
        eps_clip (float, defaults to 0.2): the clip coefficient of the policy loss
        value_clip (float, defaults to 0.4): the clip coefficient of the value loss
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of the training process
        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
        sample_replay_buffer (bool, defaults to False): whether to sample from the replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during the training process
        generate_kwargs (dict, optional): the kwargs to use while the model is generating
    """

    def __init__(self,
                 strategy: Strategy,
                 actor: Actor,
                 critic: Critic,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 actor_optim: Optimizer,
                 critic_optim: Optimizer,
                 kl_coef: float = 0.1,
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 eps_clip: float = 0.2,
                 value_clip: float = 0.4,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
        self._set_default_generate_kwargs(generate_kwargs, actor)
        actor = Actor(strategy.setup_model(actor.model))
        critic = strategy.setup_model(critic)
        reward_model = strategy.setup_model(reward_model)
        initial_model = Actor(strategy.setup_model(initial_model.model))
        experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
        replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
                         sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
        self.actor = actor
        self.critic = critic

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)

        self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model)
        self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic)

    def training_step(self, experience: Experience) -> Dict[str, float]:
        self.actor.train()
        self.critic.train()

        num_actions = experience.action_mask.size(1)
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
        # use huggingface models method directly
        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation

        if 'update_model_kwargs_fn' not in generate_kwargs:
            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
@ -0,0 +1,77 @@
from abc import ABC

import loralib as lora
from chatgpt.dataset import RewardDataset
from chatgpt.nn import PairWiseLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm


class RewardModelTrainer(ABC):
    """
    Trainer to use while training a reward model.

    Args:
        model (torch.nn.Module): the model to train
        train_dataset (RewardDataset): the dataset to use for training
        eval_dataset (RewardDataset): the dataset to use for evaluation
        batch_size (int, defaults to 1): the batch size while training
        num_epochs (int, defaults to 2): the number of epochs to train
        optim_kwargs (dict, defaults to {'lr': 1e-4}): the kwargs to use while initializing the optimizer
    """

    def __init__(self,
                 model,
                 train_dataset: RewardDataset,
                 eval_dataset: RewardDataset,
                 batch_size: int = 1,
                 num_epochs: int = 2,
                 optim_kwargs: dict = {'lr': 1e-4}) -> None:
        super().__init__()
        self.model = model
        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
        self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
        self.loss_fn = PairWiseLoss()
        self.optimizer = Adam(self.model.parameters(), **optim_kwargs)
        self.epochs = num_epochs

    def fit(self, use_lora):
        epoch_bar = tqdm(range(self.epochs), desc='Train epoch')
        for epoch in range(self.epochs):
            step_bar = tqdm(range(self.train_dataloader.__len__()), desc='Train step of epoch %d' % epoch)
            # train
            if use_lora > 0:
                print("Using Lora")
                lora.mark_only_lora_as_trainable(self.model)
            else:
                self.model.train()
            for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
                chosen_ids = chosen_ids.squeeze(1).cuda()
                c_mask = c_mask.squeeze(1).cuda()
                reject_ids = reject_ids.squeeze(1).cuda()
                r_mask = r_mask.squeeze(1).cuda()
                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                reject_reward = self.model(reject_ids, attention_mask=r_mask)
                loss = self.loss_fn(chosen_reward, reject_reward)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                step_bar.update()
                step_bar.set_postfix({'loss': loss.item()})

            # eval
            self.model.eval()
            # accumulate the reward gap over the whole eval set; initializing
            # dist inside the loop would reset it on every batch
            dist = 0
            for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                chosen_ids = chosen_ids.squeeze(1).cuda()
                c_mask = c_mask.squeeze(1).cuda()
                reject_ids = reject_ids.squeeze(1).cuda()
                r_mask = r_mask.squeeze(1).cuda()
                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                reject_reward = self.model(reject_ids, attention_mask=r_mask)
                dist += (chosen_reward - reject_reward)
            dist_mean = dist / self.eval_dataloader.__len__()
            epoch_bar.update()
            step_bar.set_postfix({'loss': loss.item(), 'dist_mean': dist_mean.item()})
            step_bar.close()
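`PairWiseLoss` is imported from `chatgpt.nn` and its definition is outside this diff. A common formulation for such a pairwise ranking loss, shown here as an assumption rather than the repository's exact code, is the negative log-sigmoid of the reward margin:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PairWiseLossSketch(nn.Module):
    """Hypothetical pairwise ranking loss: push chosen rewards above rejected ones."""

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # -log(sigmoid(r_chosen - r_rejected)), averaged over the batch
        return -F.logsigmoid(chosen_reward - reject_reward).mean()
```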
@ -0,0 +1,6 @@
from .base import Strategy
from .colossalai import ColossalAIStrategy
from .ddp import DDPStrategy
from .naive import NaiveStrategy

__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.optim as optim
from chatgpt.replay_buffer import ReplayBuffer
from torch.utils.data import DataLoader


class Strategy(ABC):
    """
    Base class for training strategies.
    """

    def __init__(self) -> None:
        super().__init__()
        self.setup_distributed()

    @abstractmethod
    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
        pass

    @abstractmethod
    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
        pass

    @abstractmethod
    def setup_distributed(self) -> None:
        pass

    @abstractmethod
    def setup_model(self, model: nn.Module) -> nn.Module:
        pass

    @abstractmethod
    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
        pass

    @abstractmethod
    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        pass

    def model_init_context(self):
        return nullcontext()
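The trainer never calls `loss.backward()` or `optimizer.step()` directly; it goes through the strategy so that each backend (naive, DDP, ColossalAI ZeRO) can intercept these calls. A minimal sketch of the expected call sequence, mirroring how `PPOTrainer.training_step` above uses it (the toy model and stand-in loss are illustrative only):

```python
import torch
from chatgpt.trainer.strategies import NaiveStrategy

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
strategy = NaiveStrategy()    # any Strategy subclass is used the same way

loss = model(torch.randn(8, 4)).pow(2).mean()   # stand-in loss for illustration
strategy.backward(loss, model, optimizer)       # backend-specific backward pass
strategy.optimizer_step(optimizer)              # backend-specific parameter update
optimizer.zero_grad()
```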
@ -0,0 +1,125 @@
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

import colossalai
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

from .ddp import DDPStrategy


class ColossalAIStrategy(DDPStrategy):
    """
    The strategy for training with ColossalAI.

    Args:
        stage(int): The stage to use in ZeRO. Choose in (1, 2, 3).
        seed(int): The seed for the random number generator.
        shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
        placement_policy(str): The placement policy for Gemini. Choose in ('cpu', 'cuda').
            If it is 'cpu', parameters, gradients and optimizer states will be offloaded to CPU.
            If it is 'cuda', they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
        pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
        force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
        search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
        hidden_dim(optional, int): The hidden dimension for Gemini. Only for ZeRO-3.
        min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
        gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
        reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
        overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
        growth_interval(int): The growth interval for the optimizer.
        hysteresis(int): The hysteresis for the optimizer.
        min_scale(float): The minimum scale for the optimizer.
        max_scale(float): The maximum scale for the optimizer.
        max_norm(float): The maximum norm for the optimizer.
        norm_type(float): The norm type for the optimizer.
    """

    def __init__(
            self,
            stage: int = 3,
            seed: int = 42,
            shard_init: bool = True,    # only for stage 3
            placement_policy: str = 'cuda',
            pin_memory: bool = True,    # only for stage 3
            force_outputs_fp32: bool = False,    # only for stage 3
            search_range_mb: int = 32,    # only for stage 3
            hidden_dim: Optional[int] = None,    # only for stage 3
            min_chunk_size_mb: float = 32,    # only for stage 3
            gpu_margin_mem_ratio: float = 0.0,    # only for stage 3
            reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2
            overlap_communication: bool = True,    # only for stage 1&2
            initial_scale: float = 2**16,
            growth_factor: float = 2,
            backoff_factor: float = 0.5,
            growth_interval: int = 1000,
            hysteresis: int = 2,
            min_scale: float = 1,
            max_scale: float = 2**32,
            max_norm: float = 0.0,
            norm_type: float = 2.0) -> None:
        super().__init__(seed)
        assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
        self.stage = stage
        self.shard_init = shard_init
        self.gemini_config = dict(device=get_current_device(),
                                  placement_policy=placement_policy,
                                  pin_memory=pin_memory,
                                  force_outputs_fp32=force_outputs_fp32,
                                  strict_ddp_mode=shard_init,
                                  search_range_mb=search_range_mb,
                                  hidden_dim=hidden_dim,
                                  min_chunk_size_mb=min_chunk_size_mb)
        if stage == 3:
            self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
        else:
            self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
                                          overlap_communication=overlap_communication,
                                          cpu_offload=(placement_policy == 'cpu'))
        self.optim_kwargs = dict(initial_scale=initial_scale,
                                 growth_factor=growth_factor,
                                 backoff_factor=backoff_factor,
                                 growth_interval=growth_interval,
                                 hysteresis=hysteresis,
                                 min_scale=min_scale,
                                 max_scale=max_scale,
                                 max_norm=max_norm,
                                 norm_type=norm_type)

    def setup_distributed(self) -> None:
        colossalai.launch_from_torch({}, seed=self.seed)

    def model_init_context(self):
        if self.stage == 3:
            world_size = dist.get_world_size()
            shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
            default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
            return ColoInitContext(device=get_current_device(),
                                   dtype=torch.half,
                                   default_pg=shard_pg,
                                   default_dist_spec=default_dist_spec)
        return super().model_init_context()

    def setup_model(self, model: nn.Module) -> nn.Module:
        return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)

    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
        assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
        return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
        optimizer.backward(loss)

    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
        optimizer.step()
@ -0,0 +1,59 @@
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

from .naive import NaiveStrategy


class DDPStrategy(NaiveStrategy):
    """
    Strategy for distributed training using torch.distributed.
    """

    def __init__(self, seed: int = 42) -> None:
        self.seed = seed
        super().__init__()

    def setup_distributed(self) -> None:
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            host = os.environ['MASTER_ADDR']
            port = int(os.environ['MASTER_PORT'])
        except KeyError as e:
            raise RuntimeError(
                f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
            )
        dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
        self.set_seed(self.seed)
        torch.cuda.set_device(local_rank)

    def set_seed(self, seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def setup_model(self, model: nn.Module) -> nn.Module:
        device = torch.cuda.current_device()
        return DDP(model, device_ids=[device])

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        sampler = DistributedSampler(replay_buffer,
                                     num_replicas=dist.get_world_size(),
                                     rank=dist.get_rank(),
                                     shuffle=True,
                                     seed=self.seed,
                                     drop_last=True)
        return DataLoader(replay_buffer,
                          batch_size=replay_buffer.sample_batch_size,
                          sampler=sampler,
                          pin_memory=pin_memory,
                          collate_fn=replay_buffer.collate_fn)
@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.optim as optim
from chatgpt.replay_buffer import ReplayBuffer
from torch.utils.data import DataLoader

from .base import Strategy


class NaiveStrategy(Strategy):
    """
    Strategy for a single GPU. No parallelism is used.
    """

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
        loss.backward()

    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
        optimizer.step()

    def setup_distributed(self) -> None:
        pass

    def setup_model(self, model: nn.Module) -> nn.Module:
        return model

    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
        return optimizer

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        return DataLoader(replay_buffer,
                          batch_size=replay_buffer.sample_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=pin_memory,
                          collate_fn=replay_buffer.collate_fn)
@ -0,0 +1,5 @@
import torch.distributed as dist


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0
@ -0,0 +1,105 @@
# Examples

## Install requirements

```shell
pip install -r requirements.txt
```

## Train with dummy prompt data

This script supports 3 strategies:

- naive
- ddp
- colossalai

It uses randomly generated prompt data.

The naive strategy only supports single-GPU training:

```shell
python train_dummy.py --strategy naive
# display cli help
python train_dummy.py -h
```

The DDP strategy and the ColossalAI strategy support multi-GPU training:

```shell
# run DDP on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
# run ColossalAI on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai
```

## Train with real prompt data

We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as the example dataset. It is a small dataset with hundreds of prompts.

You should download `prompts.csv` first.
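One possible way to fetch it, assuming the file is still hosted at this path on the Hugging Face Hub:

```shell
wget https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv
```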
This script also supports the 3 strategies.

```shell
# display cli help
python train_prompts.py -h
# run naive on 1 GPU
python train_prompts.py prompts.csv --strategy naive
# run DDP on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ddp
# run ColossalAI on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai
```

## Train the reward model

We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It is a dataset of chosen and rejected responses to the same prompts.

The dataset is downloaded from Hugging Face automatically.

Use the following commands to train your reward model:

```shell
# naive reward model training
python train_reward_model.py --pretrain <your model path>
# train with LoRA
python train_reward_model.py --pretrain <your model path> --lora_rank 16
```

## Supported Models

### GPT
- [ ] GPT2-S (s)
- [ ] GPT2-M (m)
- [ ] GPT2-L (l)
- [ ] GPT2-XL (xl)
- [ ] GPT2-4B (4b)
- [ ] GPT2-6B (6b)
- [ ] GPT2-8B (8b)
- [ ] GPT2-10B (10b)
- [ ] GPT2-12B (12b)
- [ ] GPT2-15B (15b)
- [ ] GPT2-18B (18b)
- [ ] GPT2-20B (20b)
- [ ] GPT2-24B (24b)
- [ ] GPT2-28B (28b)
- [ ] GPT2-32B (32b)
- [ ] GPT2-36B (36b)
- [ ] GPT2-40B (40b)
- [ ] GPT3 (175b)

### BLOOM
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [ ] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [ ] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
- [ ] BLOOM-175b

### OPT
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
- [ ] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
- [ ] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
- [ ] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
@ -0,0 +1 @@
pandas>=1.4.1
@ -0,0 +1,27 @@
#!/usr/bin/env bash

set -xue

if [ -z "$PROMPT_PATH" ]; then
    echo "Please set \$PROMPT_PATH to the path to prompts csv."
    exit 1
fi

BASE=$(realpath $(dirname $0))

export OMP_NUM_THREADS=8

# install requirements
pip install -r ${BASE}/requirements.txt

# train dummy
python ${BASE}/train_dummy.py --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
for strategy in ddp colossalai_gemini colossalai_zero2; do
    torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
done

# train prompts
python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
for strategy in ddp colossalai_gemini colossalai_zero2; do
    torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
done
@ -0,0 +1,112 @@
import argparse
from copy import deepcopy

import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def preprocess_batch(samples):
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def main(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == 'gpt2':
            actor = GPTActor().cuda()
            critic = GPTCritic().cuda()
        elif args.model == 'bloom':
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            actor = OPTActor().cuda()
            critic = OPTCritic().cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        max_epochs=args.max_epochs,
        train_batch_size=args.train_batch_size,
        tokenizer=preprocess_batch,
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--num_episodes', type=int, default=50)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    args = parser.parse_args()
    main(args)
@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

python train_dummy.py --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
@ -0,0 +1,112 @@
import argparse
from copy import deepcopy

import pandas as pd
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def main(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == 'gpt2':
            actor = GPTActor().cuda()
            critic = GPTCritic().cuda()
        elif args.model == 'bloom':
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            actor = OPTActor(lora_rank=args.lora_rank).cuda()
            critic = OPTCritic(lora_rank=args.lora_rank).cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        initial_model = deepcopy(actor)
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    dataset = pd.read_csv(args.prompt_path)['prompt']

    def tokenize_fn(texts):
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
        return {k: v.cuda() for k, v in batch.items()}

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        max_epochs=args.max_epochs,
        train_batch_size=args.train_batch_size,
        tokenizer=tokenize_fn,
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    trainer.fit(dataset,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt_path')
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    args = parser.parse_args()
    main(args)
@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

python train_prompts.py prompts.csv --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
@ -0,0 +1,53 @@
import argparse

import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.trainer import RewardModelTrainer
from datasets import load_dataset
from transformers import BloomTokenizerFast


def train(args):
    tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    tokenizer.pad_token = tokenizer.eos_token
    model = BLOOMRM(pretrained=args.pretrain)

    model.cuda()

    max_len = 1024

    # prepare for data and dataset
    data = load_dataset(args.dataset)
    train_data = data["train"]
    eval_data = data['test']
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    # batch_size here is expected to be C(k, 2), where k is the number of responses
    # per prompt; e.g. k = 2 responses give C(2, 2) = 1 pair per prompt.
    # Limited by the format of the 'Dahoas/rm-static' dataset, batch_size is best kept at 1.
    trainer = RewardModelTrainer(model=model,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 num_epochs=args.max_epochs)

    trainer.fit(use_lora=args.lora_rank)

    if args.lora_rank > 0:
        torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
    else:
        torch.save(trainer.model, args.save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
    parser.add_argument('--max_epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    args = parser.parse_args()
    train(args)
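When `--lora_rank > 0`, only the LoRA matrices are saved via `lora.lora_state_dict`. A sketch of restoring such a checkpoint, following the usual loralib pattern (an assumption here; this diff does not show a loading path):

```python
import loralib as lora
import torch
from chatgpt.nn import BLOOMRM

# the model must be rebuilt with the same architecture (including LoRA layers)
# as at training time; the constructor arguments here are an assumption
model = BLOOMRM(pretrained='<your model path>').cuda()
state = torch.load('rm_ckpt.pth')
# strict=False because the checkpoint contains only the LoRA weights
model.load_state_dict(state['model_state_dict'], strict=False)
```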
@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

python train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
@ -0,0 +1,6 @@
[pytest]
markers =
    cpu: tests which can run on CPU
    gpu: tests which require a single GPU
    dist: tests which are run in a multi-GPU or multi-machine environment
    experiment: tests for experimental features
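With these markers registered, test subsets can be selected with pytest's standard `-m` option, for example:

```shell
# run only the distributed tests (requires multiple GPUs)
pytest -m dist tests/
# skip everything that needs a GPU
pytest -m "not gpu and not dist" tests/
```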
@ -0,0 +1 @@
pytest
@ -0,0 +1,6 @@
transformers>=4.20.1
tqdm
datasets
loralib
colossalai>=0.2.4
torch
@ -0,0 +1,41 @@
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, 'r') as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open('README.md', encoding='utf-8') as f:
        return f.read()


def fetch_version():
    with open('version.txt', 'r') as f:
        return f.read().strip()


setup(
    name='chatgpt',
    version=fetch_version(),
    packages=find_packages(exclude=(
        'tests',
        'benchmarks',
        '*.egg-info',
    )),
    description='An RLHF implementation (ChatGPT) powered by ColossalAI',
    long_description=fetch_readme(),
    long_description_content_type='text/markdown',
    license='Apache Software License 2.0',
    url='https://github.com/hpcaitech/ChatGPT',
    install_requires=fetch_requirements('requirements.txt'),
    python_requires='>=3.6',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: Apache Software License',
        'Environment :: GPU :: NVIDIA CUDA',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: System :: Distributed Computing',
    ],
)
@ -0,0 +1,117 @@
import os
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from chatgpt.experience_maker import NaiveExperienceMaker
from chatgpt.nn import GPTActor, GPTCritic, RewardModel
from chatgpt.replay_buffer import NaiveReplayBuffer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def get_data(batch_size: int, seq_len: int = 10) -> dict:
    input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def gather_and_equal(tensor: torch.Tensor) -> bool:
    world_size = dist.get_world_size()
    outputs = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(outputs, tensor.contiguous())
    for t in outputs[1:]:
        if not torch.equal(outputs[0], t):
            return False
    return True


def run_test_data(strategy):
    EXPERIENCE_BATCH_SIZE = 4
    SAMPLE_BATCH_SIZE = 2

    if strategy == 'ddp':
        strategy = DDPStrategy()
    elif strategy == 'colossalai':
        strategy = ColossalAIStrategy(placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')

    actor = GPTActor().cuda()
    critic = GPTCritic().cuda()

    initial_model = deepcopy(actor)
    reward_model = RewardModel(deepcopy(critic.model)).cuda()

    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
    replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)

    # experience of all ranks should be the same
    for _ in range(2):
        data = get_data(EXPERIENCE_BATCH_SIZE)
        assert gather_and_equal(data['input_ids'])
        assert gather_and_equal(data['attention_mask'])
        experience = experience_maker.make_experience(**data,
                                                      do_sample=True,
                                                      max_length=16,
                                                      eos_token_id=50256,
                                                      pad_token_id=50256)
        assert gather_and_equal(experience.sequences)
        assert gather_and_equal(experience.action_log_probs)
        assert gather_and_equal(experience.values)
        assert gather_and_equal(experience.reward)
        assert gather_and_equal(experience.advantages)
        assert gather_and_equal(experience.action_mask)
        assert gather_and_equal(experience.attention_mask)
        replay_buffer.append(experience)

    # replay buffer's data should be the same
    buffer_size = torch.tensor([len(replay_buffer)], device='cuda')
    assert gather_and_equal(buffer_size)
    for item in replay_buffer.items:
        assert gather_and_equal(item.sequences)
        assert gather_and_equal(item.action_log_probs)
        assert gather_and_equal(item.values)
        assert gather_and_equal(item.reward)
        assert gather_and_equal(item.advantages)
        assert gather_and_equal(item.action_mask)
        assert gather_and_equal(item.attention_mask)

    # dataloader of each rank should have the same size and different batches
    dataloader = strategy.setup_dataloader(replay_buffer)
    dataloader_size = torch.tensor([len(dataloader)], device='cuda')
    assert gather_and_equal(dataloader_size)
    for experience in dataloader:
        assert not gather_and_equal(experience.sequences)
        assert not gather_and_equal(experience.action_log_probs)
        assert not gather_and_equal(experience.values)
        assert not gather_and_equal(experience.reward)
        assert not gather_and_equal(experience.advantages)
        # action mask and attention mask may be the same


def run_dist(rank, world_size, port, strategy):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    run_test_data(strategy)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
@rerun_if_address_is_in_use()
def test_data(world_size, strategy):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_data(2, 'colossalai')
@ -0,0 +1 @@
0.1.0
@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
    torch.nn.ReLU,
    torch.nn.Softmax,
]

# SHAPE_ARGUMENT_OPS contains nodes with (input, *shape) style args.
# This list could be extended if any other method has the same
# argument style as view and reshape.
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]
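For reference, the two shapes of argument list that view/reshape accept, in a minimal illustration (not repository code):

```python
import torch

x = torch.arange(6)
a = x.view(2, 3)        # (input, *shape) style: dimensions as separate args
b = x.reshape((2, 3))   # (input, shape) style: dimensions packed in one tuple
```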
@ -19,6 +19,8 @@ from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import SHAPE_ARGUMENT_OPS

shape_consistency_manager = ShapeConsistencyManager()
@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size],
    return size


def _solution_annotatation(gm: torch.fx.GraphModule,
                           solution: List[int],
                           strategies_constructor: StrategiesConstructor = None):
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
                               strategies_constructor: StrategiesConstructor):
    """
    This method is used to attach the solution strategy to the nodes and add the information
    required at runtime into the graph as placeholder nodes.
    """
    mod_graph = gm.graph
    # TODO: In future PR, strategies_constructor should be a required argument,
    # instead of optional argument. This is because we don't need to consider nodes with
    # no strategy in runtime preparation pass.
    if strategies_constructor is not None:
        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
        no_strategy_nodes = strategies_constructor.no_strategy_nodes
    else:
        nodes = tuple(mod_graph.nodes)
        no_strategy_nodes = []

    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
    no_strategy_nodes = strategies_constructor.no_strategy_nodes

    # the dict to get origin sharding spec of node
    origin_node_sharding_spec_dict = {}
@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
            target_sharding_specs.append(target_sharding_spec)
        sharding_spec_convert_dict[index] = target_sharding_specs
        setattr(node, 'target_sharding_specs', target_sharding_specs)

        # the get_attr node strategy is kind of a pending strategy, which means we will change it
        # to the same strategy as its user node.
        if node.op == 'get_attr':
@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
    return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict


def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    In the auto parallel system, tensors may get sharded on different devices, so the size of a tensor
    needs to be converted to the size of the original tensor and managed by the users, such as torch.view,
@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    nodes = tuple(mod_graph.nodes)
    node_pairs = {}

    # DeviceMesh information instructs the scaling of the size value
    device_mesh_info = {}
    for dim, dim_size in enumerate(device_mesh.mesh_shape):
        device_mesh_info[dim] = dim_size

    def _extract_target_dim(node):
        '''
        A helper function to extract the target dimension from a size node.
        There are two usages of torch.Tensor.size:
        1. tensor.size()
        2. tensor.size(dim)

        If a target_dim is assigned, then the output will be of type int, instead of torch.Size.
        Otherwise, the output will be of type torch.Size and this function will return None.
        '''
        target_dim = None
        if len(node.args) > 1:
            target_dim = node.args[1]
            if target_dim < 0:
                target_dim += node.args[0]._meta_data.dim()
        return target_dim

    def _post_processing(node, size_processing_node):
        '''
        This function is used to process the dependency between the size node and its users after
        inserting the size_process_node.
        '''
        # store the original node and the processing node as a pair in the node_pairs dictionary;
        # it will be used to replace the original node with the processing node in slice objects
        node_pairs[node] = size_processing_node
        size_processing_node._meta_data = node._meta_data
        if 'activation_checkpoint' in node.meta:
            size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']

        user_list = list(node.users.keys())
        for user in user_list:
            if user == size_processing_node:
                continue
            new_args = list(user.args)
            new_kwargs = dict(user.kwargs)
            # the origin node may be a positional argument or a keyword argument of the user node
            if node in new_args:
                # substitute the origin node with size_processing_node
                new_args[new_args.index(node)] = size_processing_node
                user.args = tuple(new_args)
            elif str(node) in new_kwargs:
                # substitute the origin node with size_processing_node
                new_kwargs[str(node)] = size_processing_node
                user.kwargs = new_kwargs

    def _update_slice_object_args(slice_object):
        '''
        This function is used to update the argument list of a slice object.
        If the slice object contains a Node argument, then the size node will be replaced with
        its corresponding size-processing node.
        '''
        if isinstance(slice_object, slice):
            start = slice_object.start
            stop = slice_object.stop
            step = slice_object.step
            if start in node_pairs:
                start = node_pairs[start]
            if stop in node_pairs:
                stop = node_pairs[stop]
            if step in node_pairs:
                step = node_pairs[step]
            return slice(start, stop, step)
        elif isinstance(slice_object, int):
            if slice_object in node_pairs:
                return node_pairs[slice_object]
            else:
                return slice_object
        else:
            raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")

    for node in nodes:

        if node.op == 'call_method' and node.target == 'size':
@ -154,49 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
            sharding_spec = node.args[0].sharding_spec
            dim_partition_dict = sharding_spec.dim_partition_dict

            # there are two usages of torch.Tensor.size:
            # tensor.size()
            # tensor.size(dim)
            # if a target_dim is assigned, then the output will be
            # in type of int, instead of torch.Size
            target_dim = None
            if len(node.args) > 1:
                target_dim = node.args[1]
                if target_dim < 0:
                    target_dim += node.args[0]._meta_data.dim()

            # DeviceMesh information instructs the scaling of the size value
            device_mesh_info = {}
            for dim, dim_size in enumerate(device_mesh.mesh_shape):
                device_mesh_info[dim] = dim_size
            target_dim = _extract_target_dim(node)

            # insert size_processing node
            with mod_graph.inserting_after(node):
                size_processing_node = mod_graph.create_node('call_function',
                                                             size_processing,
                                                             args=(node, dim_partition_dict, device_mesh_info,
                                                                   target_dim, node.name))
            # store original node and processing node pair in node_pairs dictionary
            # It will be used to replace the original node with processing node in slice object
            node_pairs[node] = size_processing_node
            size_processing_node._meta_data = node._meta_data
            if 'activation_checkpoint' in node.meta:
                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']

            user_list = list(node.users.keys())
            for user in user_list:
                if user == size_processing_node:
                    continue
                new_args = list(user.args)
                new_kwargs = dict(user.kwargs)
                # the origin node may be a positional argument or a keyword argument of the user node
                if node in new_args:
                    # substitute the origin node with size_processing_node
                    new_args[new_args.index(node)] = size_processing_node
                    user.args = tuple(new_args)
                elif str(node) in new_kwargs:
                    # substitute the origin node with size_processing_node
                    new_kwargs[str(node)] = size_processing_node
                    user.kwargs = new_kwargs
            _post_processing(node, size_processing_node)

        if node.op == 'call_function' and node.target == operator.getitem:
@ -217,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
            # In this pass, we need to process the last two cases because
            # node arguments may potentially appear in these cases.
            if isinstance(getitem_index, slice):
                new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
                if getitem_index.start in node_pairs:
                    new_start = node_pairs[getitem_index.start]
                elif getitem_index.stop in node_pairs:
                    new_stop = node_pairs[getitem_index.stop]
                elif getitem_index.step in node_pairs:
                    new_step = node_pairs[getitem_index.step]
                new_slice_item = slice(new_start, new_stop, new_step)
                new_slice_item = _update_slice_object_args(getitem_index)
                new_args = (node.args[0], new_slice_item)
                node.args = new_args
@ -237,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||
                    if slice_item is None:
                        new_slice_items.append(None)
                        continue

                    new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step

                    if slice_item.start in node_pairs:
                        new_start = node_pairs[slice_item.start]
                    elif slice_item.stop in node_pairs:
                        new_stop = node_pairs[slice_item.stop]
                    elif slice_item.step in node_pairs:
                        new_step = node_pairs[slice_item.step]
                    new_slice_item = slice(new_start, new_stop, new_step)
                    new_slice_item = _update_slice_object_args(slice_item)
                    new_slice_items.append(new_slice_item)

                new_args = (node.args[0], tuple(new_slice_items))

@ -255,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):

    return gm
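# Illustrative sketch (not part of the diff): a plain-Python approximation of
# the idea behind `size_processing` -- a sharded tensor's local size must be
# scaled by the mesh extents recorded in dim_partition_dict to recover the
# global size. This is an assumption-level sketch, not the actual implementation.
def _global_size_sketch(local_size, dim_partition_dict, mesh_shape):
    global_size = list(local_size)
    for dim, mesh_dims in dim_partition_dict.items():
        for mesh_dim in mesh_dims:
            global_size[dim] *= mesh_shape[mesh_dim]
    return tuple(global_size)

# a (4, 8) local shard on a (2, 2) mesh sharded as {0: [0], 1: [1]}
# corresponds to an (8, 16) global tensor
assert _global_size_sketch((4, 8), {0: [0], 1: [1]}, (2, 2)) == (8, 16)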
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    This pass will process node args to adapt to the distributed tensor layout.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    def _extract_info_from_sharding_spec(sharding_spec):
        '''
        This function is used to extract the dim_partition_dict and device_mesh from
        a sharding spec instance or a list of sharding specs.
        '''
        if isinstance(sharding_spec, ShardingSpec):
            dim_partition_dict = sharding_spec.dim_partition_dict
            device_mesh = sharding_spec.device_mesh
            return dim_partition_dict, device_mesh
        if sharding_spec is None:
            return None, None
        assert isinstance(sharding_spec,
                          (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'

        device_mesh = sharding_spec[0].device_mesh
        dim_partition_dict = []
        for element in sharding_spec:
            dim_partition_dict.append(_extract_info_from_sharding_spec(element))
        return dim_partition_dict, sharding_spec
    def _process_node_arguments(node):
        new_args = []
        for arg in node.args:
            # There are two args styles:
            #   1. (input, *shape)
            #   2. (input, shape)
            # We will extract the elements from shape and add them into new_args.
            # Finally, the args style of new_args will be unified to (input, *shape).
            if isinstance(arg, Node):
                if isinstance(arg._meta_data, (tuple, list)):
                    new_args.extend(arg._meta_data)
                elif isinstance(arg._meta_data, int):
                    new_args.append(arg._meta_data)
                else:
                    new_args.append(arg)
            else:
                assert isinstance(arg,
                                  (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
                if isinstance(arg, (tuple, list)):
                    new_args.extend(arg)
                else:
                    new_args.append(arg)
        return new_args
    def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
        new_args = _process_node_arguments(node)
        if node.op == 'call_method':
            args_to_process = list(new_args[1:])
        else:
            args_to_process = list(new_args)
        for dim, shard_dims in dim_partition_dict.items():
            total_shard_size = 1
            for shard_dim in shard_dims:
                total_shard_size *= device_mesh.shape[shard_dim]

            # we will skip the dim with -1 value
            if args_to_process[dim] == -1:
                continue
            else:
                # TODO: add an assertion here to make sure the dim size is divisible by total_shard_size
                args_to_process[dim] //= total_shard_size

        args_to_process = tuple(args_to_process)

        if node.op == 'call_method':
            new_args = (new_args[0],) + args_to_process
        else:
            new_args = args_to_process

        node.args = new_args
    def _filter_node_with_shape_args(node):
        if node.op == 'call_method':
            target = getattr(node.args[0]._meta_data.__class__, node.target)
        elif node.op == 'call_function':
            target = node.target
        else:
            target = None

        if target in SHAPE_ARGUMENT_OPS:
            return True
        return False
    for node in nodes:
        # skip the placeholder node added in the _solution_annotation pass
        if not hasattr(node, 'sharding_spec'):
            continue
        def _process_sharding_spec(sharding_spec):
            if isinstance(sharding_spec, ShardingSpec):
                dim_partition_dict = sharding_spec.dim_partition_dict
                device_mesh = sharding_spec.device_mesh
                return dim_partition_dict, device_mesh
            if sharding_spec is None:
                return None, None
            assert isinstance(sharding_spec,
                              (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'

            device_mesh = sharding_spec[0].device_mesh
            dim_partition_dict = []
            for element in sharding_spec:
                dim_partition_dict.append(_process_sharding_spec(element))
            return dim_partition_dict, sharding_spec

        output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
        new_args = []

        if node.op == 'call_method':
            method = getattr(node.args[0]._meta_data.__class__, node.target)
            # process the node with (input, *shape) style args
            if method in (torch.Tensor.view, torch.Tensor.reshape):

                for arg in node.args:
                    if isinstance(arg, Node):
                        if isinstance(arg._meta_data, (int, tuple, list)):
                            new_args.append(arg._meta_data)
                        else:
                            new_args.append(arg)
                    else:
                        assert isinstance(
                            arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
                        new_args.append(arg)

                for dim, shard_dims in output_dim_partition_dict.items():
                    total_shard_size = 1
                    for shard_dim in shard_dims:
                        total_shard_size *= device_mesh.shape[shard_dim]
                    # There are two ways to use torch.view:
                    #   1. torch.view(input, *shape)
                    #   2. torch.view(input, shape)
                    if isinstance(new_args[1], int):
                        # we will skip the dim with -1 value
                        if new_args[dim + 1] == -1:
                            continue
                        else:
                            new_args[dim + 1] //= total_shard_size
                    else:
                        new_args[1] = list(new_args[1])
                        # we will skip the dim with -1 value
                        if new_args[1][dim] == -1:
                            continue
                        else:
                            new_args[1][dim] //= total_shard_size
                node.args = tuple(new_args)

        elif node.op == 'call_function':
            target = node.target
            # process the node with (input, torch.Size) style args
            if target in (torch.reshape,):
                for arg in node.args:
                    if isinstance(arg, Node):
                        if isinstance(arg._meta_data, (tuple, list)):
                            new_args.append(list(arg._meta_data))
                        else:
                            new_args.append(arg)
                    else:
                        assert isinstance(
                            arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
                        new_args.append(list(arg))

                for dim, shard_dims in output_dim_partition_dict.items():
                    # we will skip the dim with -1 value
                    if new_args[1][dim] == -1:
                        continue
                    total_shard_size = 1
                    for shard_dim in shard_dims:
                        total_shard_size *= device_mesh.shape[shard_dim]
                    new_args[1][dim] //= total_shard_size
                node.args = tuple(new_args)
        output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
        if _filter_node_with_shape_args(node):
            _scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node)

    return gm
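# Illustrative sketch (not part of the diff): the core arithmetic of the
# args-converting pass -- each shape argument of a view/reshape is divided by
# the total shard size of its dimension so the op works on the local shard.
# A minimal sketch, assuming a device mesh of shape (2, 2).
def _scale_shape_sketch(shape, dim_partition_dict, mesh_shape):
    shape = list(shape)
    for dim, shard_dims in dim_partition_dict.items():
        if shape[dim] == -1:    # -1 is inferred by torch, so it is left alone
            continue
        total_shard_size = 1
        for shard_dim in shard_dims:
            total_shard_size *= mesh_shape[shard_dim]
        shape[dim] //= total_shard_size
    return shape

# a global view(8, -1, 16) sharded as {0: [0], 2: [1]} becomes view(4, -1, 8)
assert _scale_shape_sketch([8, -1, 16], {0: [0], 2: [1]}, (2, 2)) == [4, -1, 8]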
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
    """
    Apply the sharding action to the module parameters and buffers following the
    instructions of the solver solution.

@ -361,6 +386,50 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
    nodes = tuple(mod_graph.nodes)
    # This stream is created for overlapping the communication and computation.
    reduction_stream = torch.cuda.Stream()
    def _add_hook_for_grad_communication(node, param):

        comm_actions = node.best_strategy.communication_actions

        def _filter_param_to_hook(node, op_data, comm_action):
            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
                return True
            if node.op == 'get_attr' and isinstance(
                    node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
                return True
            return False

        for operation_data, comm_action in comm_actions.items():
            comm_spec_to_use = comm_action.comm_spec
            # register a hook on the parameter
            if _filter_param_to_hook(node, operation_data, comm_action):

                def wrapper(param, comm_spec, stream, overlap):

                    def hook_fn(grad):
                        if overlap:
                            with torch.cuda.stream(stream):
                                _all_reduce(grad, comm_spec, async_op=True)
                        else:
                            _all_reduce(grad, comm_spec, async_op=False)

                    param.register_hook(hook_fn)

                wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)

    def _shard_param(param, target_sharding_spec):
        # apply the sharding spec to the parameter
        if target_sharding_spec.dim_partition_dict != {}:
            origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
            setattr(param, 'sharding_spec', origin_sharding_spec)
            # TODO: build a ColoParameter class to manage the distributed parameters
            # we could use .data here, because all the operations happen before the real training
            # loop, so we don't need to track these operations in the autograd graph.
            param = torch.nn.Parameter(
                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
                                                                         target_sharding_spec).detach().clone())
        return param
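# Illustrative sketch (not part of the diff): the hook pattern above in
# miniature -- a backward hook that launches the gradient reduction on a side
# stream when overlap is enabled. torch.distributed.all_reduce stands in for
# the pass's `_all_reduce` helper (an assumption about its behaviour, not the
# actual call path); running this requires an initialized process group.
import torch
import torch.distributed as dist

def register_overlapped_allreduce(param: torch.nn.Parameter, stream: torch.cuda.Stream, overlap: bool = True):

    def hook_fn(grad):
        if overlap:
            # enqueue the reduction on the side stream so it can overlap
            # with the remaining backward computation
            with torch.cuda.stream(stream):
                dist.all_reduce(grad, async_op=True)
        else:
            dist.all_reduce(grad, async_op=False)

    # fires every time a gradient w.r.t. `param` is computed
    param.register_hook(hook_fn)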
    for node in nodes:
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)

@ -370,35 +439,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
            setattr(target_module, 'processed', True)
            for name, param in target_module.named_parameters():
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                # apply the sharding spec to the parameter
                if target_sharding_spec.dim_partition_dict != {}:
                    origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                    setattr(param, 'sharding_spec', origin_sharding_spec)
                    # TODO: build a ColoParameter class to manage the distributed parameters
                    # we could use .data here, because all the operations happen before the real training
                    # loop, so we don't need to track these operations in the autograd graph.
                    param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
                        param.data, param.sharding_spec, target_sharding_spec).detach().clone()
                param = _shard_param(param, target_sharding_spec)

                setattr(target_module, name, param)
                comm_actions = node.best_strategy.communication_actions
                for operation_data, comm_action in comm_actions.items():
                    comm_spec_to_use = comm_action.comm_spec
                    # register a hook on the parameter
                    if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

                        def wrapper(param, comm_spec, stream, overlap):

                            def hook_fn(grad):
                                if overlap:
                                    with torch.cuda.stream(stream):
                                        _all_reduce(grad, comm_spec, async_op=True)
                                else:
                                    _all_reduce(grad, comm_spec, async_op=False)

                            param.register_hook(hook_fn)

                        wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
                _add_hook_for_grad_communication(node, param)

            sharded_buffer_dict = {}
            # apply the sharding spec to the buffers

@ -426,36 +470,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
            target = getattr(target_module, atoms[-1])

            target_sharding_spec = node.sharding_spec
            if target_sharding_spec.dim_partition_dict != {}:
                origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
                setattr(target, 'sharding_spec', origin_sharding_spec)
                # TODO: build a ColoParameter class to manage the distributed parameters
                # we could use .data here, because all the operations happen before the real training
                # loop, so we don't need to track these operations in the autograd graph.
                target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
                    target.data, target.sharding_spec, target_sharding_spec).detach().clone()
            target = _shard_param(target, target_sharding_spec)

            assert hasattr(target_module, atoms[-1])
            setattr(target_module, atoms[-1], target)
            _add_hook_for_grad_communication(node, target)

            comm_actions = node.best_strategy.communication_actions
            for operation_data, comm_action in comm_actions.items():
                comm_spec_to_use = comm_action.comm_spec
                # register a hook on the parameter
                if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

                    def wrapper(param, comm_spec, stream, overlap):

                        def hook_fn(grad):
                            if overlap:
                                with torch.cuda.stream(stream):
                                    _all_reduce(grad, comm_spec, async_op=True)
                            else:
                                _all_reduce(grad, comm_spec, async_op=False)

                        param.register_hook(hook_fn)

                    wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
    return gm
@ -469,14 +489,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule,
                             solution: List[int],
                             device_mesh: DeviceMesh,
                             strategies_constructor: StrategiesConstructor = None,
                             strategies_constructor: StrategiesConstructor,
                             overlap=False):
    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
        gm, solution, strategies_constructor)
    gm = _size_value_converting(gm, device_mesh)
    gm = _node_args_converting(gm, device_mesh)
    gm = size_value_converting_pass(gm, device_mesh)
    gm = node_args_converting_pass(gm, device_mesh)
    # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
    # gm = implicit_comm_action_apply(gm)
    gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
    gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap)

    return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
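# Illustrative sketch (not part of the diff): how the renamed pass pipeline is
# typically driven end to end. The traced module, solver solution and mesh are
# assumption-level stand-ins; only the runtime_preparation_pass call shape
# comes from the code above.
#
#     gm = <torch.fx.GraphModule traced from the target model>
#     solution = <one strategy index per node, produced by the solver>
#     gm, convert_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
#         gm, solution, device_mesh, strategies_constructor, overlap=True)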
@ -1,6 +0,0 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
@ -1,142 +0,0 @@
import functools
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union

import torch
from torch.fx.node import Node

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import INFINITY_COST
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
                           dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
    """
    Generate the sharding spec of the tensor based on the given dim_partition_dict.

    Args:
        input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is given, the meta data associated with it will be used.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
        dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs; the key is the tensor dimension and the value is the list of mesh dimensions to shard on.
    """

    if isinstance(input_, Node):
        assert hasattr(input_, '_meta_data'), 'The given node has no attribute _meta_data'
        meta_tensor = input_._meta_data
        assert meta_tensor is not None, "The given node's _meta_data attribute is None"
        shape = meta_tensor.shape
    elif isinstance(input_, torch.Tensor):
        shape = input_.shape
    else:
        raise TypeError(
            f'We cannot generate a sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
        )
    for dim_index, sharding_index_list in dim_partition_dict.items():
        sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
        sharding_size = reduce(operator.mul, sharding_list, 1)
        assert shape[
            dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of the tensor into {sharding_size} partitions.'

    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
    return sharding_spec
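# Illustrative sketch (not part of the diff): what dim_partition_dict means in
# plain numbers -- {0: [0], 1: [1]} on a mesh of shape (2, 4) shards dim 0 two
# ways and dim 1 four ways, so an (8, 16) tensor yields (4, 4) local shards.
# Pure-Python check of the same divisibility rule the assertion above enforces.
from functools import reduce
import operator

def local_shard_shape(shape, dim_partition_dict, mesh_shape):
    shape = list(shape)
    for dim, mesh_dims in dim_partition_dict.items():
        sharding_size = reduce(operator.mul, (mesh_shape[m] for m in mesh_dims), 1)
        assert shape[dim] % sharding_size == 0, f'dim {dim} is not divisible by {sharding_size}'
        shape[dim] //= sharding_size
    return tuple(shape)

assert local_shard_shape((8, 16), {0: [0], 1: [1]}, (2, 4)) == (4, 4)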
def generate_resharding_costs(nodes: List[Node],
                              sharding_specs: List[ShardingSpec],
                              count_backward: Optional[bool] = True,
                              dtype: Optional[torch.dtype] = None,
                              index=None):
    '''
    Compute the resharding costs with this specific strategy.

    Argument:
        nodes (List[Node]): a list of nodes
        sharding_specs (List[ShardingSpec]): a list of ShardingSpecs for the nodes.
        count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
    '''
    # The resharding_cost of the weight is counted because of weight-sharing cases.
    resharding_costs = {}
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

    # the shape consistency manager is a singleton class
    shape_consistency_manager = ShapeConsistencyManager()

    for input_node, input_spec in zip(nodes, sharding_specs):
        resharding_costs[input_node] = []
        for strategy in input_node.strategies_vector:
            input_sharding_spec = strategy.output_sharding_spec
            if not isinstance(input_sharding_spec, ShardingSpec):
                assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
                input_sharding_spec = input_sharding_spec[index]
            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
            try:
                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

                # we need to multiply by the element size of the dtype to get the correct communication cost
                resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
            except AssertionError as e:
                warnings.warn(f'{e}')
                resharding_cost = INFINITY_COST
            resharding_costs[input_node].append(resharding_cost)
    return resharding_costs
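# Illustrative sketch (not part of the diff): the shape of the returned
# structure -- one cost list per input node, indexed by the producer's
# strategy index, with INFINITY_COST marking infeasible conversions.
# The node names and values below are toy stand-ins.
INFINITY_COST_DEMO = 1e13
resharding_costs_demo = {
    'conv1': [0.0, 4096.0, INFINITY_COST_DEMO],    # 3 candidate strategies
    'bn1': [2048.0, 0.0],                          # 2 candidate strategies
}
best_for_conv1 = min(range(3), key=resharding_costs_demo['conv1'].__getitem__)
assert best_for_conv1 == 0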
def ignore_sharding_exception(func):
    """
    A decorator which catches the AssertionError raised inside the wrapped
    function, converts it into a warning, and returns None instead of failing.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            rst = func(*args, **kwargs)
            return rst
        except AssertionError as e:
            warnings.warn(f'{e}')

    return wrapper
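# Illustrative sketch (not part of the diff): the decorator in action -- an
# invalid sharding attempt degrades to a warning and a None result instead of
# aborting strategy enumeration. Self-contained restatement of the wrapper above.
import functools
import warnings

def _ignore_sharding_exception_demo(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except AssertionError as e:
            warnings.warn(f'{e}')

    return wrapper

@_ignore_sharding_exception_demo
def build_strategy(dim_size):
    assert dim_size % 2 == 0, 'dimension is not shardable'
    return f'shard into {dim_size // 2} x 2'

assert build_strategy(4) == 'shard into 2 x 2'
assert build_strategy(3) is None    # a warning is emitted and enumeration continues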
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
    dim_partition_list = []
    # enumerate all the 2D sharding cases
    for i in range(dim_size):
        for j in range(i + 1, dim_size):
            dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
            dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
            dim_partition_list.append(dim_partition_dict_0)
            dim_partition_list.append(dim_partition_dict_1)
    for i in range(dim_size):
        dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
        dim_partition_list.append(dim_partition_dict_flatten)

    return dim_partition_list
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
    dim_partition_list = []
    # enumerate all the 1D sharding cases
    for i in range(dim_size):
        dim_partition_dict_0 = {i: [mesh_dim_0]}
        dim_partition_list.append(dim_partition_dict_0)

    return dim_partition_list
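# Illustrative sketch (not part of the diff): the enumeration above worked out
# for a 2-D tensor (dim_size=2) on mesh dims 0 and 1.
#   2D cases:   {0: [0], 1: [1]} and {0: [1], 1: [0]}
#   flattened:  {0: [0, 1]} and {1: [0, 1]}
#   1D cases:   {0: [0]} and {1: [0]}
expected_2d = [{0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}]
expected_1d = [{0: [0]}, {1: [0]}]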
def generate_sharding_size(dim_partition_dict, device_mesh):
    total_sharding_size = 1
    for mesh_dim_list in dim_partition_dict.values():
        mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
        sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
        total_sharding_size *= sharding_size

    return total_sharding_size
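# Illustrative sketch (not part of the diff): generate_sharding_size multiplies
# the mesh extents of every sharded dimension -- with a (2, 4) mesh,
# {0: [0], 1: [1]} gives 2 * 4 = 8 shards in total.
from functools import reduce
import operator

mesh_shape_demo = (2, 4)
dim_partition_demo = {0: [0], 1: [1]}
total_sharding_size_demo = 1
for mesh_dims in dim_partition_demo.values():
    total_sharding_size_demo *= reduce(operator.mul, (mesh_shape_demo[m] for m in mesh_dims))
assert total_sharding_size_demo == 8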
@ -1,83 +0,0 @@
import torch
import operator

__all__ = [
    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
    torch.abs,
    torch.cos,
    torch.exp,
    operator.neg,
    torch.multiply,
    torch.nn.functional.relu,
    torch.nn.functional.dropout,
    # softmax should not be here
    torch.nn.functional.softmax
]
ELEMENTWISE_METHOD_OP = [
    torch.Tensor.to,
    torch.Tensor.type,
    # TODO: contiguous may need some extra processing.
    torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_METHOD_OP = [
    torch.Tensor.view,
    torch.Tensor.unsqueeze,
    torch.Tensor.split,
    torch.Tensor.permute,
    torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
    torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
    operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
]
CONV_MODULE_OP = [
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
    torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = [
    torch.flatten,
    torch.reshape,
    torch.abs,
    torch.cos,
    torch.exp,
    operator.neg,
    torch.multiply,
    torch.nn.functional.relu,
    torch.nn.functional.dropout,
    torch.flatten,
    torch.where,
    operator.pow,
    torch.pow,
    torch.tanh,
    torch.add,
    torch.sub,
    torch.mul,
    torch.div,
    torch.floor_divide,
    torch.true_divide,
    operator.add,
    operator.sub,
    operator.mul,
    operator.floordiv,
    operator.truediv,
    # softmax should not be here
    torch.nn.functional.softmax
]

INFINITY_COST = 1e13
@ -1,174 +0,0 @@
import math
from typing import List

from torch.fx.node import Node

from .constants import INFINITY_COST
class CostGraph:
    '''
    A graph data structure to simplify the edge cost graph. It has two main functions:
    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
       CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
    2. To reduce the search space, we merge computationally-trivial operators, such as
       element-wise operators, transpose, and reduction, into their following nodes. The merging information will
       be given by the StrategiesVector depending on the type of the target node and the following nodes.

    Argument:
        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
        simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
    '''

    def __init__(self, leaf_strategies, simplify=True):
        self.leaf_strategies = leaf_strategies
        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
        # stores the number of strategies in each node
        self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
        # extra_node_costs will store the extra costs introduced by merging nodes
        self.extra_node_costs = {}
        self.following_dict = {}
        self.simplify = simplify
        self._build_cost_graph()
    def _remove_invalid_node(self, node, attr_name):
        remove_list = []
        target_node_list = getattr(node, attr_name, [])
        for target_node in target_node_list:
            if target_node not in self.nodes:
                remove_list.append(target_node)
        for element in remove_list:
            target_node_list.remove(element)
    def _build_cost_graph(self):
        '''
        This method generates the edge_cost for each adjacent node pair. Additionally, 'parents' and 'children'
        attributes will be set on each node.
        '''
        self.edge_costs = {}
        if self.simplify:
            self.merge_pair = []
        for strategies_vector in self.leaf_strategies:
            # build edge_cost
            dst_node = strategies_vector.node
            for src_node in strategies_vector.predecessor_nodes:
                if src_node not in self.nodes:
                    continue
                node_pair = (src_node, dst_node)
                # src_index = strategies_vector.predecessor_nodes.index(src_node)
                edge_cost = {}
                for i in range(len(strategies_vector)):
                    for j in range(len(src_node.strategies_vector)):
                        edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
                self.edge_costs[node_pair] = edge_cost
            # add parents and children attributes to the node
            setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
            setattr(dst_node, 'children', strategies_vector.successor_nodes)
            self._remove_invalid_node(dst_node, 'parents')
            self._remove_invalid_node(dst_node, 'children')

            if self.simplify and strategies_vector.check_merge():
                for followed_node in strategies_vector.predecessor_nodes:
                    self.merge_pair.append((followed_node, dst_node))
    def get_edge_cost(self, src_node, dst_node):
        return self.edge_costs[(src_node, dst_node)]
    def merge_node(self, src_node, dst_node):
        '''
        To merge dst_node into src_node, we need to take the following steps:

        1. For each strategy of src_node, we need to pick an appropriate strategy
        of dst_node to merge with. This is important because the logical resharding costs
        between the parent nodes of src_node and the merged node depend on how the
        src_node strategies are dispatched. For example, for the graph 0->1->2, after merging node 1
        into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)],
        where x is the strategy of node 1 picked for strategy 0 of node 2.

        2. We need to accumulate the extra costs introduced by merging nodes. The extra costs
        contain two parts: one is the resharding cost between the src_node strategy and the dst_node strategy,
        the other is the original extra cost of the src_node strategy.

        3. Build connections between the new node pairs, and remove dst_node after all consumer nodes are
        detached from it.

        Argument:
            src_node(Node): The node that absorbs dst_node.
            dst_node(Node): The node to be merged into src_node.
        '''
        src_node_index = dst_node.parents.index(src_node)
        # build merge_map
        merge_map = {}
        for src_index, strategy in enumerate(src_node.strategies_vector):
            min_cost = INFINITY_COST
            lowest_cost_index = -1
            for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
                resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
                if resharding_cost <= min_cost:
                    min_cost = resharding_cost
                    lowest_cost_index = dst_index
            merge_map[src_index] = lowest_cost_index

        # extra_node_cost for the src node
        self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
        for src_index, strategy in enumerate(src_node.strategies_vector):
            target_strate_index = merge_map[src_index]
            target_strategy = dst_node.strategies_vector[target_strate_index]
            self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
            if dst_node in self.extra_node_costs:
                self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]

        # add the new node pairs to the cost graph
        for child_node in dst_node.children:
            new_node_pair = (src_node, child_node)
            old_node_pair = (dst_node, child_node)
            if new_node_pair in self.edge_costs:
                continue
            edge_cost = {}
            for i in range(self.node_lens[src_node]):
                for j in range(self.node_lens[child_node]):
                    dst_strate_index = merge_map[i]
                    # dst_strategy = dst_node.strategies_vector[dst_strate_index]
                    edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
            if new_node_pair not in self.edge_costs:
                self.edge_costs[new_node_pair] = edge_cost
            else:
                # we should accumulate the resharding costs if the args of the child node contain
                # both the src node and the dst node.
                for index_pair, resharding_cost in self.edge_costs[new_node_pair].items():
                    self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]

        # connect the src node and the children of the dst node
        dst_node.parents.remove(src_node)
        src_node.children.remove(dst_node)
        self.edge_costs.pop((src_node, dst_node))
        for child_node in dst_node.children:
            if child_node not in src_node.children:
                src_node.children.append(child_node)
            if src_node not in child_node.parents:
                child_node.parents.append(src_node)
            # remove the dst node from the cost graph when the dst node has no producer.
            if len(dst_node.parents) == 0:
                child_node.parents.remove(dst_node)
                node_pair = (dst_node, child_node)
                self.edge_costs.pop(node_pair)
        if len(dst_node.parents) == 0:
            self.following_dict[dst_node] = src_node
            dst_node.children = []
    def _reindexing_src(self, src):
        if src not in self.following_dict:
            return src
        return self._reindexing_src(self.following_dict[src])
    def simplify_graph(self):
        if not self.simplify:
            return
        self.merge_pair.reverse()
        for (src_node, dst_node) in self.merge_pair:
            self.merge_node(src_node, dst_node)
        self.merge_pair.reverse()
        reindexing_following_dict = {}
        for dst, src in self.following_dict.items():
            reindexing_following_dict[dst] = self._reindexing_src(src)
        self.following_dict = reindexing_following_dict
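# Illustrative sketch (not part of the diff): edge-cost linearization in
# miniature. For a src node with 2 strategies and a dst node with 3, the
# quadratic resharding table is flattened into a dict keyed by
# (src_strategy_index, dst_strategy_index), which is the form the solver consumes.
resharding_table = [[0.0, 5.0, 1e13],    # src strategy 0 vs dst strategies 0..2
                    [3.0, 0.0, 2.0]]     # src strategy 1 vs dst strategies 0..2
edge_cost_demo = {(j, i): resharding_table[j][i] for j in range(2) for i in range(3)}
assert edge_cost_demo[(1, 2)] == 2.0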
@ -1,165 +0,0 @@
from collections import OrderedDict as ODict
from dataclasses import dataclass
from typing import Any, List, OrderedDict, Union

from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node

from colossalai.fx.passes.utils import get_node_module

__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
|
||||
@dataclass
|
||||
class LiveVariable:
|
||||
"""
|
||||
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
|
||||
"""
|
||||
name: str
|
||||
node: Node
|
||||
is_inplace: bool
|
||||
|
||||
|
||||
class LiveVariableVector(list):
|
||||
"""
|
||||
LiveVariableVector is a data structure to store the list of LiveVariable objects.
|
||||
"""
|
||||
|
||||
def exists(self, name) -> bool:
|
||||
"""
|
||||
Check if a variable has already existed in the current list by name.
|
||||
"""
|
||||
for var in self:
|
||||
if name == var.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, name) -> LiveVariable:
|
||||
for var in self:
|
||||
if name == var.name:
|
||||
return var
|
||||
raise KeyError(f"Variable {name} is not found")
|
||||
|
||||
def copy(self) -> "LiveVariableVector":
|
||||
"""
|
||||
Create a copy of this vector
|
||||
"""
|
||||
vector = LiveVariableVector()
|
||||
for var in self:
|
||||
vector.append(var)
|
||||
return vector
|
||||
|
||||
|
||||
@dataclass
|
||||
class LiveStage:
|
||||
"""
|
||||
LiveStage is a data structure to record the living variables at this current node.
|
||||
"""
|
||||
name: str
|
||||
node: Node
|
||||
all_live_vars: LiveVariableVector
|
||||
unique_live_vars: LiveVariableVector
|
||||
|
||||
|
||||
class GraphAnalyser:

    def __init__(self, gm: GraphModule):
        self._gm = gm
        self._graph = gm.graph

    @property
    def gm(self) -> GraphModule:
        """
        Return the GraphModule object associated with this analyser.
        """
        return self._gm

    @property
    def graph(self) -> Graph:
        """
        Return the Graph object associated with this analyser.
        """
        return self._graph

    def liveness_analysis(self) -> List[LiveStage]:
        """
        Analyse the graph to obtain the variable liveness information. This function returns
        a list of LiveStage objects, one per compute stage.
        """
        compute_nodes = self.graph.nodes
        liveness_list = []

        # checked: records all variables created since the first stage.
        # all: records the live variables that exist up to the current stage;
        #   this can differ from the `checked` list as some variables may be destroyed prior to this stage.
        # unique: records the unique live variables that exist up to the current stage;
        #   this differs from the `all` list as some variables are duplicated.
        checked_variables = LiveVariableVector()
        all_live_variables = LiveVariableVector()
        unique_live_vars = LiveVariableVector()

        for idx, node in enumerate(compute_nodes):
            #############################
            # find new living variables #
            #############################
            # detect whether the current op is an in-place op;
            # if it is an in-place op, we deem it a duplicate var
            is_inplace = False
            if node.op == 'call_function':
                # check if this is an in-place op such as torch.nn.functional.relu(x, inplace=True)
                if node.kwargs.get('inplace', False):
                    is_inplace = True
            elif node.op == 'call_module':
                # check if this is an in-place module such as torch.nn.ReLU(inplace=True)
                module = get_node_module(node)
                if getattr(module, 'inplace', False):
                    is_inplace = True

            # add the output var
            meta = getattr(node, '_meta_data', None)
            live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
            if not is_inplace:
                unique_live_vars.append(live_var)
            checked_variables.append(live_var)
            all_live_variables.append(live_var)

            # check if any input is not checked yet
            for arg in node.args:
                if not isinstance(arg, Node):
                    continue
                arg_name = arg.name
                if not checked_variables.exists(arg_name):
                    live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
                    all_live_variables.append(live_var_from_arg)
                    checked_variables.append(live_var_from_arg)
                    unique_live_vars.append(live_var_from_arg)

            # TODO: add the logic to remove live variables;
            # this should be completed once we are able to trace the backward compute graph

            # add this stage to the liveness list
            stage = LiveStage(name=node.name,
                              node=node,
                              all_live_vars=all_live_variables.copy(),
                              unique_live_vars=unique_live_vars.copy())
            # if a LiveStage is covered by another LiveStage, we just keep the larger one.
            replace = False
            for index, prev_stage in enumerate(liveness_list):
                all_covered = True
                for ele in prev_stage.unique_live_vars:
                    if ele not in stage.unique_live_vars:
                        all_covered = False
                        break
                if all_covered:
                    replace = True
                    break
            if replace:
                liveness_list[index] = stage
            else:
                liveness_list.append(stage)

        return liveness_list

    def get_alias_set(self):
        pass
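# Illustrative sketch (not part of the diff): driving the analyser on a plain
# fx-traceable model. The GraphAnalyser constructor and liveness_analysis()
# match the class above; the model and printing are assumption-level stand-ins,
# so this is kept as a commented sketch.
#
#     import torch
#     from torch.fx import symbolic_trace
#
#     model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(inplace=True))
#     gm = symbolic_trace(model)
#     for stage in GraphAnalyser(gm).liveness_analysis():
#         print(stage.name, len(stage.unique_live_vars))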
@ -1,15 +0,0 @@
from .batch_norm_handler import BatchNormHandler
from .bcast_op_handler import BcastOpHandler
from .conv_handler import ConvHandler
from .dot_handler import DotHandler
from .embedding_handler import EmbeddingHandler
from .layer_norm_handler import LayerNormHandler
from .operator_handler import OperatorHandler
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler

__all__ = [
    'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
    'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler'
]
@ -1,492 +0,0 @@
import operator
from functools import reduce

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector

from .operator_handler import OperatorHandler

__all__ = ['BatchNormHandler']


class BatchNormHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of normalization.

    To keep the math consistent, there are two ways to do BatchNorm if the input
    is sharded along the batch dimension:
    1. Gather the input partitions along the batch dimension, then do the normal BatchNorm.
    2. Run SyncBatchNorm on each input partition separately; the SyncBN op will
       keep the computation correct.
    In this handler, both methods are considered.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.bias = self.module_named_parameters['bias']
        self.output_data = self.node._meta_data
        self._sanity_check()
    def _sanity_check(self):
        '''
        In the sanity check, we need to make sure the input data has the correct number of dimensions.
        For BatchNorm1d, the dim of the input data should be 3 ([N, C, L]).
        For BatchNorm2d, the dim of the input data should be 4 ([N, C, H, W]).
        For BatchNorm3d, the dim of the input data should be 5 ([N, C, H, W, D]).
        '''
        assert self.input_data.dim() in (
            3, 4, 5), 'We expect the dim of the input fed into the batch norm op to be in the range [3, 5].'
    def _generate_compute_cost(self, bs, channel_in):
        '''
        Compute the computation cost per device with this specific strategy.

        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.

        Argument:
            bs(int): Batch size of the input data.
            channel_in(int): The channel dimension of the input data.

        Return:
            compute_cost(float): Computation cost per device with this specific strategy
        '''
        # TODO: compute_cost needs to be divided by TFLOPS, now it just shows the computation size.
        # TODO: a constant coefficient needs to be added.
        # 1D: (L) * N * Cin
        # 2D: (H * W) * N * Cin
        # 3D: (H * W * D) * N * Cin

        input_size = self.input_data.shape[2:]
        input_size_product = reduce(operator.mul, input_size, 1)
        forward_compute_cost = input_size_product * bs * channel_in
        backward_activation_compute_cost = input_size_product * bs * channel_in
        backward_weight_compute_cost = input_size_product * bs * channel_in
        backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
        compute_cost = forward_compute_cost + backward_compute_cost
        return compute_cost
    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward partitions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation partitions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight partitions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy; the first element of this tuple is the forward
                memory cost, and the second element is the backward memory cost.
            memory_cost_forward(float): Memory cost of the forward activation per
                device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of the backward activation
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        numel_input = numel_output
        numel_weight = self.weight.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory_cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory_cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory_cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
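# Illustrative sketch (not part of the diff): worked numbers for the memory-cost
# formula above. A float32 output with 4 * 8 * 16 * 16 = 8192 elements occupies
# 8192 * 4 = 32768 bytes; sharded 4 ways in the forward pass, each device holds
# 32768 / 4 = 8192 bytes of activation.
numel_output_demo = 4 * 8 * 16 * 16            # (N, C, H, W)
size_per_elem_bytes_demo = 4                   # float32
sharding_size_forward_demo = 4
memory_cost_forward_activation_demo = numel_output_demo * size_per_elem_bytes_demo / sharding_size_forward_demo
assert memory_cost_forward_activation_demo == 8192.0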
    @ignore_sharding_exception
    def split_input_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
        compute_cost = self._generate_compute_cost(bs, channel_in)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
                                                       sharding_size_weight)

        # This strategy does not need an all_reduce operation
        communication_cost = 0
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)
        # shard the output batch dimension to get all possible sharding strategies from this basic strategy
        new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

        dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
        new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
        # the computation cost is all the same
        new_compute_cost = compute_cost

        # the memory cost needs to be recomputed for the new strategy
        new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
        new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
            new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # the communication cost needs to include the sharding cost of this strategy
        origin_communication_cost = communication_cost
        tiny_shard_cost = 10
        new_forward_communication_cost = tiny_shard_cost
        # we need to all-gather the batch dimension for the basic strategy
        new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1)
        new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost

        sharding_strategies = ShardingStrategy(new_name,
                                               output_sharding_spec=new_sharding_spec_for_output,
                                               compute_cost=new_compute_cost,
                                               communication_cost=new_communication_cost,
                                               memory_cost=new_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)
    @ignore_sharding_exception
    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
                                                  self.device_mesh.shape[mesh_dim_1])
        compute_cost = self._generate_compute_cost(bs, channel_in)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
                                                       sharding_size_weight)

        # This strategy does not need an all_reduce operation
        communication_cost = 0
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)
    @ignore_sharding_exception
    def non_split(self, mesh_dim_0, mesh_dim_1):
        name = 'RR = RR x R'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = 1
        sharding_size_weight = 1
        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
                                                       sharding_size_weight)

        # This strategy does not need an all_reduce operation
        communication_cost = 0
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)
        def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
            dim_partition_dict_for_output = {0: mesh_dim_list}
            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

            # the computation cost is all the same
            new_compute_cost = compute_cost

            # the memory cost needs to be recomputed
            new_sharding_size_input = 1
            for mesh_dim in mesh_dim_list:
                new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim]
            new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
                new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight)

            # the communication cost needs to include the sharding cost of this strategy
            origin_communication_cost = communication_cost
            tiny_shard_cost = 10
            new_forward_communication_cost = tiny_shard_cost
            if len(mesh_dim_list) == 1:
                new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation,
                                                                                   mesh_dim_list[0])
            else:
                new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost(
                    memory_cost_backward_activation, 0)
            new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost

            new_sharding_strategy = ShardingStrategy(new_name,
                                                     output_sharding_spec=new_sharding_spec_for_output,
                                                     compute_cost=new_compute_cost,
                                                     communication_cost=new_communication_cost,
                                                     memory_cost=new_memory_cost,
                                                     resharding_costs=resharding_costs,
                                                     input_shardings=(sharding_spec_for_input,
                                                                      sharding_spec_for_weight))

            return new_sharding_strategy

        # shard the output batch dimension to get all possible sharding strategies from this basic strategy
        # shard on mesh_dim_0
        new_name = f'S{mesh_dim_0}R = RR x R'
        mesh_dim_list = [mesh_dim_0]
        new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
        self.strategies_vector.append(new_sharding_strategy)

        # shard on mesh_dim_1
        new_name = f'S{mesh_dim_1}R = RR x R'
        mesh_dim_list = [mesh_dim_1]
        new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
        self.strategies_vector.append(new_sharding_strategy)

        # shard on mesh_dim_0, mesh_dim_1
        new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R'
        mesh_dim_list = [mesh_dim_0, mesh_dim_1]
        new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
        self.strategies_vector.append(new_sharding_strategy)
    @ignore_sharding_exception
    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
        channel_in = self.input_data.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
        sharding_size_weight = 1
        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
                                                                                    sharding_size_backward_activation,
                                                                                    sharding_size_weight)

        # the all-reduce communication will happen during the SyncBN computation.
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)
    @ignore_sharding_exception
    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
        channel_in = self.input_data.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_weight = 1
        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
                                                                                    sharding_size_backward_activation,
                                                                                    sharding_size_weight)

        # the all_reduce communication happens during the SyncBN computation.
        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0)
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(bs, channel_in)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
                                                                                    sharding_size_backward_activation,
                                                                                    sharding_size_weight)

        # the all_reduce communication happens during the SyncBN computation.
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate all possible sharding strategies for a BatchNorm node, and record them in the strategies_vector.

        Example:
            norm_handler = BatchNormHandler(node, strategies_vector,
                                            self.shape_consistency_manager)
            norm_handler.register_strategy()
            for strategy in norm_handler.strategies_vector:
                print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')

        Output:
            RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
            RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
            RR = RR x R, computation_cost: 262144, memory_cost: 1048576
            RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
        '''

        # RS = RS x S and strategies based on it, such as
        # SS = RS x S
        self.split_input_channel(0, 1)
        self.split_input_channel(1, 0)

        # RR = RR x R and strategies based on it, such as
        # SR = SR x R
        self.non_split(0, 1)

        # RS01 = RS01 x S01
        self.split_input_channel_1d(0, 1)

        # SR = SR x R WITH SYNC_BN
        self.split_input_batch(0)
        self.split_input_batch(1)

        # SS = SS x S WITH SYNC_BN
        self.split_input_both_dim(0, 1)
        self.split_input_both_dim(1, 0)

        # S01R = S01R x R WITH SYNC_BN
        self.split_input_batch_1d(0, 1)

        return self.strategies_vector

@ -1,552 +0,0 @@
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
                                                                     enumerate_all_possible_2d_sharding,
                                                                     ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .operator_handler import OperatorHandler

__all__ = ['BcastOpHandler']


class BcastOpHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of broadcast operators (such as operator.add).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert len(self.predecessor_node) == 2
        self.lhs_data = self.predecessor_node[0]._meta_data
        self.rhs_data = self.predecessor_node[1]._meta_data
        self.lhs = self.predecessor_node[0]
        self.rhs = self.predecessor_node[1]
        self.output_data = self.node._meta_data

    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
        shape = list(input_.shape)

        # pad the shape to the same length as output_data
        while len(shape) < self.output_data.dim():
            shape.insert(0, 1)
        shape = torch.Size(shape)

        # if the sharding happens on a size-one dimension, we should record it as R.
        processed_dim_partition_dict = deepcopy(dim_partition_dict)
        for dim_index, _ in dim_partition_dict.items():
            if shape[dim_index] == 1:
                processed_dim_partition_dict.pop(dim_index)
        for dim_index, sharding_index_list in processed_dim_partition_dict.items():
            sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
            sharding_size = reduce(operator.mul, sharding_list, 1)
            assert shape[dim_index] % sharding_size == 0, \
                f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=shape,
                                     dim_partition_dict=processed_dim_partition_dict)

        return sharding_spec

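    # Illustrative sketch (hypothetical shapes, not part of the original logic):
    # broadcasting a tensor of shape (4,) against an output of shape (2, 3, 4)
    # pads its shape to (1, 1, 4); a partition requested on one of the padded,
    # size-one dimensions is then dropped and recorded as R:
    #
    #     dim_partition_dict           == {0: [0], 2: [1]}   # requested sharding
    #     processed_dim_partition_dict == {2: [1]}           # dim 0 has size 1
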
    def _generate_compute_cost(self, total_sharding_size):
        lhs_matrix_shape = self.lhs_data.shape[-2:]
        rhs_matrix_shape = self.rhs_data.shape[-2:]
        batch_dimensions_shape = self.output_data.shape[:-2]
        batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
        # FLOP count of a batched matmul [b, i, k] x [b, k, j]: 2 * b * i * k * j,
        # scaled down by the number of devices the computation is sharded over.
        compute_cost = reduce(
            operator.mul, lhs_matrix_shape) * rhs_matrix_shape[-1] * batch_dimensions_product * 2 / total_sharding_size
        return compute_cost

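    # Worked example (hypothetical shapes, assuming the 2 * b * i * k * j FLOP
    # count above): for lhs of shape [4, 32, 64], rhs of shape [4, 64, 16] and a
    # total sharding size of 8,
    #
    #     compute_cost == (32 * 64) * 16 * 4 * 2 / 8 == 32768.0
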
    def _generate_resharding_costs(self, sharding_specs):
        # The resharding_cost of the weight is counted due to weight-sharing cases.
        dtype = self.node._meta_data.dtype
        nodes = self.predecessor_node
        resharding_costs = {}
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # shape consistency manager is a singleton class
        shape_consistency_manager = ShapeConsistencyManager()

        for input_node, input_spec in zip(nodes, sharding_specs):
            resharding_costs[input_node] = []
            for strategy in input_node.strategies_vector:
                input_sharding_spec = strategy.output_sharding_spec
                assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
                # if the input shape is smaller than the target input, we will pad the input to the same length as the target.
                # Then, use the padded input sharding spec to compute the resharding cost.
                if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
                    new_entire_shape = list(input_sharding_spec.entire_shape)
                    while len(new_entire_shape) < len(input_spec.entire_shape):
                        new_entire_shape.insert(0, 1)
                    new_entire_shape = torch.Size(new_entire_shape)
                    new_device_mesh = input_sharding_spec.device_mesh
                    new_dim_partition_dict = input_sharding_spec.dim_partition_dict
                    input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
                                                       entire_shape=new_entire_shape,
                                                       dim_partition_dict=new_dim_partition_dict)

                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

                # we need to multiply by the element size of the dtype to get the correct communication cost
                resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
                resharding_costs[input_node].append(resharding_cost)

        return resharding_costs

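    # Note (illustrative, hypothetical numbers): shape_consistency reports the
    # resharding cost in number of elements, so a float32 transition whose
    # "total" cost is 1024 elements is recorded as 1024 * 4 == 4096 bytes here.
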
    def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):

        sharding_spec_list = []
        check_duplicated_list = []
        for output_dim_partition_dict in dim_partition_list:
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                break
            sharding_seq = output_sharding_spec.sharding_sequence
            if sharding_seq not in check_duplicated_list:
                check_duplicated_list.append(sharding_seq)
                sharding_spec_list.append(output_sharding_spec)

        return sharding_spec_list

    def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        # use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.

        output_dim_partition_list = []
        dim_size = self.output_data.dim()
        # enumerate all the 2D sharding cases
        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_2d)

        # enumerate all the 1D sharding cases
        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_1)

        # add an empty dict for the fully replicated case
        output_dim_partition_list.append({})
        output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)

        return output_sharding_spec_list

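    # Illustrative sketch (assuming a 2-dim output and mesh axes 0 and 1): the
    # enumeration above produces partition dicts such as
    #
    #     {0: [0], 1: [1]}, {0: [1], 1: [0]}, ...    # 2D cases
    #     {0: [0]}, {1: [0]}, {0: [1]}, {1: [1]}     # 1D cases
    #     {}                                         # fully replicated
    #
    # before duplicates and invalid specs are filtered out by
    # _convert_partition_dict_to_sharding_spec.
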
    @ignore_sharding_exception
    def _register_strategy(self, output_sharding_spec):
        dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)

        name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
        dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the computation cost of this strategy
        sharding_dims = []
        for mesh_dims in dim_partition_dict_for_output.values():
            for mesh_dim in mesh_dims:
                sharding_dims.append(self.device_mesh.shape[mesh_dim])
        sharding_size = reduce(operator.mul, sharding_dims, 1)
        memory_cost = self.output_data.numel() / sharding_size
        compute_cost = memory_cost
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=output_sharding_spec,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    ################################################
    # used to generate strategies for torch.matmul #
    ################################################
    @ignore_sharding_exception
    def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
        # this dim partition dict only describes the batch dimensions; in this scenario,
        # the matrix dimensions are fully replicated, so no extra processing is needed.
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim)
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim)
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the memory cost of this strategy
        batch_sharding_dims = []
        for mesh_dims in dim_partition_dict_for_batch_dim.values():
            for mesh_dim in mesh_dims:
                batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
        batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
        # in this case, total_sharding_size is equal to the batch sharding size
        memory_cost = self.output_data.numel() / batch_sharding_size

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(batch_sharding_size)

        # in this case, no communication takes place.
        # TODO: add all-reduce cost if lhs or rhs is of type Parameter.
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
        # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j].
        # This dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info onto it.
        # In this scenario, the matrix dimensions will be sharded on the 'i' dimension.

        # in this case, the matrix dimensions of lhs are sharded on the 'i' dimension.
        dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
        dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix})

        # in this case, the matrix dimensions of rhs are fully replicated.
        dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)

        # in this case, the matrix dimensions of the output are sharded on the 'i' dimension.
        dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
        dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix})

        # generate sharding specs
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the memory cost of this strategy
        total_sharding_dims = []

        # append batch sharding dims
        for mesh_dims in dim_partition_dict_for_batch_dim.values():
            for mesh_dim in mesh_dims:
                total_sharding_dims.append(self.device_mesh.shape[mesh_dim])

        # append the sharding dims on the matrix dimension
        for mesh_dim in mesh_dim_on_matrix:
            total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
        total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)

        # in this case, output_data uses all the sharding dims.
        memory_cost = self.output_data.numel() / total_sharding_size
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # TODO: add all-reduce cost if lhs or rhs is of type Parameter.
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
        # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j].
        # This dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info onto it.
        # In this scenario, the matrix dimensions will be sharded on the 'k' dimension.

        # in this case, the matrix dimensions of lhs are sharded on the 'k' dimension.
        dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
        dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix})

        # in this case, the matrix dimensions of rhs are sharded on the 'k' dimension.
        dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
        dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix})

        # in this case, the matrix dimensions of the output are fully replicated.
        dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)

        # generate sharding specs
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the memory cost of this strategy
        total_sharding_dims = []
        batch_sharding_dims = []
        # append batch sharding dims
        for mesh_dims in dim_partition_dict_for_batch_dim.values():
            for mesh_dim in mesh_dims:
                total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
                batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])

        # append the sharding dims on the matrix dimension
        for mesh_dim in mesh_dim_on_matrix:
            total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
        batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
        total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)

        # in this case, output_data is fully replicated on the matrix dimensions.
        memory_cost = self.output_data.numel() / batch_sharding_size

        compute_cost = self._generate_compute_cost(total_sharding_size)

        # TODO: add all-reduce cost if lhs or rhs is of type Parameter.
        # The communication takes place during forward activation computation.
        if len(mesh_dim_on_matrix) == 1:
            communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
        else:
            communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

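    # Note (illustrative): when the contracted 'k' dimension is sharded, each
    # device multiplies lhs[..., :, k_shard] with rhs[..., k_shard, :] and only
    # obtains a partial sum of the output; the full result materializes once the
    # partial sums are combined by the all_reduce charged above, which is why the
    # communication happens in the forward pass for this split.
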
    @ignore_sharding_exception
    def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
        # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j].
        # This dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info onto it.
        # In this scenario, the matrix dimensions will be sharded on the 'j' dimension.

        # in this case, the matrix dimensions of lhs are fully replicated.
        dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)

        # in this case, the matrix dimensions of rhs are sharded on the 'j' dimension.
        dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
        dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix})

        # in this case, the matrix dimensions of the output are sharded on the 'j' dimension.
        dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
        dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix})

        # generate sharding specs
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the memory cost of this strategy
        total_sharding_dims = []

        # append batch sharding dims
        for mesh_dims in dim_partition_dict_for_batch_dim.values():
            for mesh_dim in mesh_dims:
                total_sharding_dims.append(self.device_mesh.shape[mesh_dim])

        # append the sharding dims on the matrix dimension
        for mesh_dim in mesh_dim_on_matrix:
            total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
        total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)

        # in this case, output_data uses all the sharding dims.
        memory_cost = self.output_data.numel() / total_sharding_size
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # TODO: add all-reduce cost if lhs or rhs is of type Parameter.
        # The communication takes place during backward activation computation.
        if len(mesh_dim_on_matrix) == 1:
            communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
        else:
            communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list):
        self._split_dim_i(dim_partition_dict, mesh_dim_list)
        self._split_dim_k(dim_partition_dict, mesh_dim_list)
        self._split_dim_j(dim_partition_dict, mesh_dim_list)

    @ignore_sharding_exception
    def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)

        dim_partition_dict_for_rhs = {-2: [mesh_dim_1]}
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)

        dim_partition_dict_for_output = {-2: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the memory cost of this strategy
        total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
        # the output is sharded on mesh_dim_0 only, so each device holds
        # numel / device_mesh.shape[mesh_dim_0] elements of it.
        memory_cost = self.output_data.numel() / self.device_mesh.shape[mesh_dim_0]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # TODO: add all-reduce cost if lhs or rhs is of type Parameter.
        # The communication takes place during forward activation computation.
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)

        dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)

        dim_partition_dict_for_output = {-1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the memory cost of this strategy
        total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
        # the output is sharded on mesh_dim_1 only, so each device holds
        # numel / device_mesh.shape[mesh_dim_1] elements of it.
        memory_cost = self.output_data.numel() / self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # TODO: add all-reduce cost if lhs or rhs is of type Parameter.
        # The communication takes place during forward and backward activation computation.
        communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
        communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
        communication_cost = communication_cost_backward_activation + communication_cost_forward_activation

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
        sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)

        dim_partition_dict_for_rhs = {-1: [mesh_dim_1]}
        sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)

        dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])

        # compute the memory cost of this strategy
        total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
        # in this case, output_data uses all the sharding dims.
        memory_cost = self.output_data.numel() / total_sharding_size
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # TODO: add all-reduce cost if lhs or rhs is of type Parameter.
        # The communication takes place during backward activation computation.
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))

        self.strategies_vector.append(sharding_strategies)

    def _registry_2d_strategies_for_matmul(self):
        self._split_lhs_space_both_contract(0, 1)
        self._split_lhs_space_both_contract(1, 0)
        self._split_rhs_space_both_contract(0, 1)
        self._split_rhs_space_both_contract(1, 0)
        self._split_lhs_space_rhs_space(0, 1)
        self._split_lhs_space_rhs_space(1, 0)

    def register_strategy(self) -> StrategiesVector:
        MESH_DIM_LIST = [0, 1]
        if self.node.target != torch.matmul:
            output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
            for output_sharding_spec in output_sharding_specs:
                self._register_strategy(output_sharding_spec)
        else:
            # we only care about the non-computing dimensions,
            # therefore, we omit the last two dimensions.
            dim_size = self.output_data.dim() - 2

            # both device mesh axes are used on batch dimensions
            dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
            for dim_partition_dict in dim_partition_dicts_2d:
                self._registry_no_split_strategies_for_matmul(dim_partition_dict)

            # only one device mesh axis is used on batch dimensions
            for mesh_dim_index in [0, 1]:
                dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
                for dim_partition_dict in dim_partition_dicts_1d:
                    self._registry_no_split_strategies_for_matmul(dim_partition_dict)
                    self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])

            # no device mesh axis is used on batch dimensions
            dim_partition_dict_on_batch_dim = {}
            self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim)
            self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST)
            self._registry_2d_strategies_for_matmul()

        return self.strategies_vector

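    # Usage sketch (hypothetical handler arguments, mirroring the
    # BatchNormHandler docstring example earlier in this section):
    #
    #     bcast_handler = BcastOpHandler(node, device_mesh, strategies_vector)
    #     bcast_handler.register_strategy()
    #     for strategy in bcast_handler.strategies_vector:
    #         print(strategy.name, strategy.compute_cost)
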
@ -1,609 +0,0 @@
import operator
import warnings
from functools import reduce

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector

from .operator_handler import OperatorHandler

__all__ = ['ConvHandler']


class ConvHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of Convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.output_data = self.node._meta_data
        self._sanity_check()

    def _sanity_check(self):
        '''
        In the sanity check, we need to make sure the input data has the correct number of dimensions.
        For Conv1d, the dim of input data should be 3 ([N, C, L]).
        For Conv2d, the dim of input data should be 4 ([N, C, H, W]).
        For Conv3d, the dim of input data should be 5 ([N, C, H, W, D]).
        '''
        assert self.input_data.dim() in (3, 4, 5), \
            'We suppose the dim of input fed into conv op should be in range of [3, 5].'

    def _generate_compute_cost(self, bs, channel_in, channel_out):
        '''
        Compute the computation cost per device with this specific strategy.

        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.

        Argument:
            bs(int): Batch size of the input data.
            channel_in(int): The channel dimension of input data.
            channel_out(int): The out channel of the conv weight.

        Return:
            compute_cost(float): Computation cost per device with this specific strategy
        '''
        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
        # 1D: (L) * N * Cout * Cin * kernel
        # 2D: (H * W) * N * Cout * Cin * kernel
        # 3D: (H * W * D) * N * Cout * Cin * kernel
        output_size = self.output_data.shape[2:]
        output_size_product = reduce(operator.mul, output_size, 1)
        input_size = self.input_data.shape[2:]
        input_size_product = reduce(operator.mul, input_size, 1)
        kernel_size = self.weight.shape[2:]
        kernel_size_product = reduce(operator.mul, kernel_size, 1)
        forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
        backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
        backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
        compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
        return compute_cost

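    # Worked example (hypothetical Conv2d shapes): for a 16 x 16 output and input
    # feature map, a 3 x 3 kernel, bs = 8, channel_in = 64 and channel_out = 128,
    #
    #     forward_compute_cost == (16 * 16) * 8 * 64 * 128 * 9 == 150994944
    #
    # and since the input and output spatial sizes coincide here, all three terms
    # match, giving compute_cost == 3 * 150994944 == 452984832.
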
    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward number partitions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation number partitions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight number partitions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy; the first element of this tuple is the forward
                memory cost, and the second element of this tuple is the backward
                memory cost.
            memory_cost_forward_activation(float): Memory cost of forward activation per
                device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of backward activation
                per device with this specific strategy.
            memory_cost_backward_weight(float): Memory cost of backward weight
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        numel_input = self.input_data.numel()
        numel_weight = self.weight.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory_cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory_cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory_cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

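    # Worked example (hypothetical float32 tensors): with numel_output == 2**20,
    # numel_weight == 2**16, sharding_size_forward == 4 and sharding_size_weight == 2,
    #
    #     memory_cost_forward_activation == 2**20 * 4 / 4 == 1048576.0 bytes
    #     memory_cost_forward_weight     == 2**16 * 4 / 2 == 131072.0 bytes
    #     memory_cost_forward            == 1179648.0 bytes
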
    @ignore_sharding_exception
    def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
        memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # This strategy does not need an all_reduce operation during the forward phase
        communication_cost_forward = 0
        # compute the backward communication cost to all-reduce the input activation grad
        communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
                                                                                  mesh_dim_1)
        # compute the backward communication cost to all-reduce the weight grad due to data parallelism
        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
        # total communication cost
        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
        sharding_size_weight = 1
        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
                                                                                    sharding_size_backward_activation,
                                                                                    sharding_size_weight)

        # This strategy does not need an all_reduce operation during the forward phase.
        communication_cost_forward = 0
        # compute the backward communication cost to all-reduce the weight grad due to data parallelism
        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
        # compute the total cost
        communication_cost = communication_cost_forward + communication_cost_backward_weight
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # the in-channel dimension of the weight is sharded on mesh_dim_1,
        # matching the input's in-channel sharding in the strategy name above
        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during the forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
        # This strategy does not need an all_reduce operation to compute the input activation grad
        communication_cost_backward_activation = 0
        # compute the backward communication cost to all-reduce the weight grad due to data parallelism
        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
        # compute total cost
        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
        channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during the forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        # compute the communication cost of this strategy during the backward phase
        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
        communication_cost = communication_cost_forward + communication_cost_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
        name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during the forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        # This strategy does NOT need an all_reduce during the backward phase
        communication_cost_backward = 0
        communication_cost = communication_cost_forward + communication_cost_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_weight_out_channel(self, mesh_dim_0):
        name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim_0]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
        sharding_size_backward_activation = 1
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
        memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # This strategy does not need an all_reduce during the forward phase
        communication_cost_forward = 0
        # compute the communication cost of this strategy during the backward phase
        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
        communication_cost = communication_cost_forward + communication_cost_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def non_split(self):
        name = 'RR = RR x RR'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = 1
        sharding_size_weight = 1
        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
                                                          sharding_size_weight)

        # this strategy does not need an all_reduce in either the forward or backward phase
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
            mesh_dim_1]
        sharding_size_weight = 1
        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
                                                                                    sharding_size_backward_activation,
                                                                                    sharding_size_weight)

        # this strategy does not need an all_reduce in the forward phase
        communication_cost_forward = 0
        # compute the backward communication cost of all-reducing the weight gradient, required by data parallelism
        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            memory_cost_backward_weight, 0)
        # compute the total communication cost
        communication_cost = communication_cost_backward_weight + communication_cost_forward

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
                                                  self.device_mesh.shape[mesh_dim_1])
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
            mesh_dim_1]
        sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost during the forward phase
        communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            memory_cost_forward_activation, 0)
        # this strategy does NOT need an all_reduce during the backward phase
        communication_cost_backward = 0
        communication_cost = communication_cost_forward + communication_cost_backward

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate all possible strategies for a Conv node, and record them in the strategies_vector.

        Example:
            physical_mesh_id = torch.arange(0, 4)
            mesh_shape = (2, 2)
            # [[0, 1]
            #  [2, 3]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
            shape_consistency_manager = ShapeConsistencyManager()

            model = ConvModel(16, 32)
            input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
            # graph():
            #     %x : torch.Tensor [#users=1] = placeholder[target=x]
            #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
            #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
            #     return conv
            graph = tracer.trace(root=model, meta_args=input_sample)
            gm = GraphModule(model, graph, model.__class__.__name__)
            gm.recompile()
            # [x, mul, conv, output]
            nodes = [node for node in gm.graph.nodes]

            # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
            strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
            setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

            strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
            conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
                                       device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
            conv_handler.register_strategy()
            for strategy in conv_handler.strategies_vector:
                print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')

        Output:
            S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
            S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
            S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
            S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
            S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
            S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
            RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
            RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
            RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
            RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
            RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
            RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
            RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
            S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]}
            RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
        '''
        # SS = SR x RS
        self.split_input_batch_weight_out_channel(0, 1)
        self.split_input_batch_weight_out_channel(1, 0)

        # SR = SR x RR
        self.split_input_batch(0)
        self.split_input_batch(1)

        # SR = SS x SR
        self.split_input_both_dim_weight_in_channel(0, 1)
        self.split_input_both_dim_weight_in_channel(1, 0)

        # RS = RS x SS
        self.split_input_in_channel_weight_both_channel(0, 1)
        self.split_input_in_channel_weight_both_channel(1, 0)

        # RR = RS x SR
        self.split_input_in_channel_weight_in_channel(0)
        self.split_input_in_channel_weight_in_channel(1)

        # RS = RR x RS
        self.split_weight_out_channel(0)
        self.split_weight_out_channel(1)

        # RR = RR x RR
        self.non_split()

        # S01R = S01R x RR
        self.split_1d_parallel_on_input_batch(0, 1)

        # RR = RS01 x S01R
        self.split_1d_parallel_on_in_channel(0, 1)

        return self.strategies_vector


CONV_STRATEGIES_LIST = [
    'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
    'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
    'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
]

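# Sanity check (illustrative, assuming a ConvHandler instance built as in the
# docstring example above): for a 2D device mesh, the names registered by
# register_strategy() should match CONV_STRATEGIES_LIST one-to-one.
#
#   strategies = conv_handler.register_strategy()
#   assert [s.name for s in strategies] == CONV_STRATEGIES_LIST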
@ -1,756 +0,0 @@
import operator
from enum import Enum
from functools import reduce
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector

from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from .operator_handler import OperatorHandler
from .strategy_generator import IntermediateStrategy, StrategyGenerator

__all__ = ['DotHandler']


class DotProductStrategyGenerator(StrategyGenerator):
    """
    DotProductStrategyGenerator is used to generate the sharding strategies for the dot product of two 1D tensors.
    This is created for torch.matmul where both tensors are 1D. As torch.matmul does not accept a bias argument,
    we do not consider bias here.
    """

    def validate(self, input, other):
        assert input.dim() == 1 and other.dim() == 1

    def no_split(self):
        name = 'R = R dot R'
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_one_dim(self, mesh_dim):
        name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # do not split any dimension for the dot product
        # R = R dot R
        strategy_list.append(self.no_split())

        # split both tensors along the same dimension
        # S = S dot S
        strategy_list.append(self.split_one_dim(0))
        strategy_list.append(self.split_one_dim(1))

        return strategy_list


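# Illustrative usage (a minimal sketch; the constructor arguments shown here are
# hypothetical, as the real StrategyGenerator signature lives in strategy_generator.py):
#
#   generator = DotProductStrategyGenerator(operation_data_mapping, device_mesh)
#   for strategy in generator.generate():
#       print(strategy.name)    # R = R dot R, S0 = S0 dot S0, S1 = S1 dot S1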
class MatVecStrategyGenerator(StrategyGenerator):

    def validate(self, input, other):
        assert input.dim() > 1 and other.dim() == 1

    def no_split(self):
        name = "R = R x R"
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_input_batch(self, mesh_dim):
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # no split
        strategy_list.append(self.no_split())

        # split the batch dim of the first tensor only, once per mesh dimension
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        return strategy_list


class MatMulStrategyGenerator(StrategyGenerator):
    """
    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
    a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.

    A matmul can be formulated as [n, p] x [p, q] = [n, q]

    Args:
        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
            This incurs an extra transformation of the dim partitioning, as the weight is transposed.
    """

    def __init__(self, is_linear: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_linear = is_linear

        # as the weight of the linear module is transposed, we compute the
        # corresponding dimension indices in advance for convenience
        if is_linear:
            self.dim_q = 0
            self.dim_p = 1
        else:
            self.dim_q = 1
            self.dim_p = 0

    def validate(self, input, other, bias) -> bool:
        # make sure the second tensor is a 2D tensor
        assert input.dim() > 0 and other.dim() == 2

        # make sure bias is of the same dimension
        if self.is_linear:
            assert bias is None or bias.shape[-1] == other.shape[0]
        else:
            assert bias is None or bias.shape[-1] == other.shape[1]

    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle the case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0]
            },
            "other": {
                self.dim_p: [mesh_dim_0],
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def recompute_split_both_contract(self, mesh_dim):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim]
            },
            "other": {
                self.dim_p: [mesh_dim]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_rhs_space_only(self, mesh_dim):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim]
            },
            "bias": {
                -1: [mesh_dim]
            },
            "output": {
                -1: [mesh_dim]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name,
                                    dim_partition_dict=dim_partition_dict,
                                    all_reduce_axis=[mesh_dim_0, mesh_dim_1])

    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)


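# Note on dim_p / dim_q above: for nn.Linear and F.linear the weight is stored
# transposed as [q, p], so the contracting dimension p has index 1 and the output
# dimension q has index 0; for a plain matmul the weight is [p, q] and the indices
# swap. Sharding the output dimension of a Linear weight therefore means
# partitioning dim 0, which is exactly what {self.dim_q: [mesh_dim]} expresses.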
class BatchedMatMulStrategyGenerator(StrategyGenerator):
    """
    Generate sharding strategies for a batched matrix multiplication.

    A batched matrix multiplication can be viewed as
    [b, i, k] x [b, k, j] -> [b, i, j]
    """

    def __init__(self, is_torch_bmm: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_torch_bmm = is_torch_bmm

    def validate(self, input, other, bias) -> bool:
        if self.is_torch_bmm:
            assert input.shape == other.shape
            assert input.dim() > 2
            if bias is not None:
                assert other.shape[-1] == bias.shape[0]
        else:
            # TODO: validate that these inputs are broadcastable
            pass

    def split_one_batch_dim(self):
        if 1 in self.device_mesh.mesh_shape:
            mesh_dim = self.device_mesh.mesh_shape.index(1)
            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
            dim_partition_dict = {
                "input": {
                    0: [mesh_dim]
                },
                "other": {
                    0: [mesh_dim]
                },
                "bias": {},
                "output": {
                    0: [mesh_dim]
                }
            }
            return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
        else:
            return None

    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # split only the batch dimension
        # Sb = Sb x Sb
        # can be None as it only applies to a 1D device mesh
        strategy = self.split_one_batch_dim()
        if strategy:
            strategy_list.append(strategy)

        # split the batch dim of both inputs and the i dim of the first tensor
        # SbSi = SbSi x Sb
        strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_lhs_space(1, 0))

        # split the batch dim of both inputs and the j dim of the second tensor
        # SbSj = Sb x SbSj
        strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_rhs_space(1, 0))

        # split the batch dim of both inputs and the k dim of both inputs
        # Sb = SbSk x SbSk, needs an all-reduce over the k dim
        strategy_list.append(self.split_batch_dim_both_contract(0, 1))
        strategy_list.append(self.split_batch_dim_both_contract(1, 0))

        # split two batch dims
        strategy_list.append(self.split_two_batch_dim(0, 1))
        strategy_list.append(self.split_two_batch_dim(1, 0))

        return strategy_list


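# Note on BatchedMatMulStrategyGenerator.generate above: on a 2D mesh it yields
# up to nine strategies (the SbSi, SbSj and SbSk variants for both mesh-dim
# orders, plus two two-batch-dim splits); split_one_batch_dim contributes only
# when one mesh dimension has size 1, i.e. the mesh is effectively 1D.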
class DotHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
        # TODO: consider bias addition
        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
        return compute_cost

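    # Worked example (illustrative): for input_shape = (4, 8), weight_shape = (16, 8)
    # and a 2 x 2 mesh (total_sharding_size = 4), the cost is
    # (4 * 8) * 16 * 2 // 4 = 256. The factor of 2 counts one multiply and one add
    # per output element; dividing by total_sharding_size gives the per-device share.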
    @ignore_sharding_exception
    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle the case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # the linear layer weight is transposed during init
        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost
        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost
        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
        communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
        communication_cost = communication_cost_activation_backward + communication_cost_weight_backward

        # create and register the strategy
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

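    # Note on split_lhs_space_rhs_space above: the forward pass is communication-free,
    # but the backward pass all-reduces the input gradient over mesh_dim_1 (the weight
    # shard dimension) and the weight gradient over mesh_dim_0 (the batch shard
    # dimension), which is why both terms appear in communication_cost.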
    @ignore_sharding_exception
    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # since the weight of the linear layer is transposed,
        # the actual dim to be sharded is 1
        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost of this strategy
        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
        communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
        communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost of this strategy
        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
        communication_cost = communication_cost_activation_backward + communication_cost_activation_forward

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def recompute_split_both_contract(self, mesh_dim):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

        dim_partition_dict_for_input = {1: [mesh_dim]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[mesh_dim]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_rhs_space_only(self, mesh_dim):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[mesh_dim]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost of this strategy
        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
        communication_cost = communication_cost_activation_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

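    # Note: recompute_split_both_contract (RR = RS x SR) and split_rhs_space_only
    # (RS = RR x RS) correspond to the row-parallel and column-parallel linear
    # layers popularized by Megatron-LM: the former all-reduces activations in the
    # forward pass, the latter all-reduces input gradients in the backward pass.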
    @ignore_sharding_exception
    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost of this strategy
        communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
        communication_cost = communication_cost_weight_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost of this strategy
        communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            activation_memory_cost, 0)
        communication_cost = communication_cost_forward_activation
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

        # compute the communication cost of this strategy
        communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            input_grad_memory_cost, 0)
        communication_cost = communication_cost_activation_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=total_memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate all possible strategies for a linear node, and record them in the strategies_vector.
        '''
        # SS = SR x RS
        self.split_lhs_space_rhs_space(0, 1)
        self.split_lhs_space_rhs_space(1, 0)

        # SR = SS x SR
        self.split_lhs_space_both_contract(0, 1)
        self.split_lhs_space_both_contract(1, 0)

        # RS = RS x SS
        self.split_rhs_space_both_contract(0, 1)
        self.split_rhs_space_both_contract(1, 0)

        # RR = RS x SR
        self.recompute_split_both_contract(0)
        self.recompute_split_both_contract(1)

        # RS = RR x RS
        self.split_rhs_space_only(0)
        self.split_rhs_space_only(1)

        # S01R = S01R x RR
        self.split_lhs_1st_dim_1d(0, 1)

        # RR = RS01 x S01R
        self.split_lhs_2nd_dim_1d(0, 1)

        # RS01 = RR x RS01
        self.split_rhs_2nd_dim_1d(0, 1)

        return self.strategies_vector

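# Illustrative check (a minimal sketch, assuming `dot_handler` is a constructed
# DotHandler): register_strategy() records up to 13 strategies on a 2D mesh --
# both mesh-dim orders for each 2D split plus the three S01/RS01 1D variants;
# strategies whose sharding spec raises are silently skipped by
# @ignore_sharding_exception.
#
#   strategies = dot_handler.register_strategy()
#   assert len(strategies) <= 13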
@ -1,179 +0,0 @@
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .operator_handler import OperatorHandler

__all__ = ['EmbeddingHandler']


class EmbeddingHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of embedding operators (such as nn.Embedding).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, total_sharding_size):
        input_shape = self.input_data.shape
        weight_shape = self.weight.shape
        input_shape_product = reduce(operator.mul, input_shape, 1)
        weight_shape_product = reduce(operator.mul, weight_shape, 1)
        compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
        return compute_cost

    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward partitions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation partitions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight partitions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy; the first element of this tuple is the forward
                memory cost, and the second element is the backward memory cost.
            memory_cost_forward_activation(float): Memory cost of the forward activation
                per device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of the backward activation
                per device with this specific strategy.
            memory_cost_backward_weight(float): Memory cost of the weight gradient
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        numel_input = self.input_data.numel()
        numel_weight = self.weight.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory_cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

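    # Worked example (illustrative): for a (1024, 128) weight, a (4, 16) input
    # whose output is (4, 16, 128), 4-byte elements, and sharding sizes
    # (forward=2, backward_activation=1, weight=4), the forward activation cost is
    # 4 * 16 * 128 * 4 / 2 = 16384 bytes and the weight cost is
    # 1024 * 128 * 4 / 4 = 131072 bytes per device.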
    @ignore_sharding_exception
    def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {2: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = 1
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during the forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        # compute the communication cost of this strategy during the backward phase
        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
        communication_cost = communication_cost_forward + communication_cost_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

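    # Note on split_weight_both_dim above: sharding dim 0 of the embedding table
    # splits the vocabulary, so each device can look up only the indices it owns
    # and the partial results must be all-reduced over mesh_dim_0 in the forward
    # pass -- the vocab-parallel embedding pattern.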
    @ignore_sharding_exception
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_weight = 1
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # this strategy does not need an all_reduce during the forward phase
        communication_cost_forward = 0
        # compute the communication cost of this strategy during the backward phase
        communication_cost_backward_activation = 0
        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            memory_cost_backward_weight, 0)
        communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
        communication_cost = communication_cost_forward + communication_cost_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate all possible strategies for an Embedding node, and record them in the strategies_vector.
        '''
        # RRS = RR x SS
        self.split_weight_both_dim(0, 1)
        self.split_weight_both_dim(1, 0)

        # SSR = SS x RR
        self.split_input_both_dim(0, 1)
        self.split_input_both_dim(1, 0)

        return self.strategies_vector

@ -1,241 +0,0 @@
import operator
from functools import reduce

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
    enumerate_all_possible_1d_sharding,
    enumerate_all_possible_2d_sharding,
    generate_sharding_size,
    ignore_sharding_exception,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector

from .operator_handler import OperatorHandler

__all__ = ['LayerNormHandler']


class LayerNormHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of normalization.

    Note: to keep the math consistent, LayerNorm does not allow sharding on the hidden dimension.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.bias = self.module_named_parameters['bias']
        self.output_data = self.node._meta_data

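    # Rationale (illustrative): LayerNorm computes a mean and variance over the
    # normalized (hidden) dimensions, so sharding those dimensions would force a
    # cross-device reduction for every statistic; only the batch-like leading
    # dimensions are therefore candidates for sharding.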
def _generate_compute_cost(self, total_sharding_size):
|
||||
'''
|
||||
Compute the computation cost per device with this specific strategy.
|
||||
|
||||
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
||||
|
||||
Argument:
|
||||
bs(int): Batch size of the input data.
|
||||
channel_in(int): The channel dimension of input data.
|
||||
|
||||
Return:
|
||||
compute_cost(float): Computation cost per device with this specific strategy
|
||||
'''
|
||||
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
||||
# TODO: a constant coefficient need to be added.
|
||||
|
||||
norm_kernel_size = self.weight.shape
|
||||
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
|
||||
input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
|
||||
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
|
||||
norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
|
||||
forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
# To compute gradient of on norm kernel element requires input_batch_product times computation, so
|
||||
# the total cost is input_batch_product * norm_kernel_product
|
||||
backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
|
||||
compute_cost = forward_compute_cost + backward_compute_cost
|
||||
return compute_cost
|
||||
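    # Illustrative example (not in the original source): for an input of shape
    # (8, 16, 64) normalized over a weight of shape (64,), the batch product is
    # 8 * 16 = 128 and the kernel product is 64. On a 2 x 2 mesh
    # (total_sharding_size = 4) the forward cost is 128 * 64 / 4 = 2048, and the
    # backward cost is twice that, giving compute_cost = 3 * 2048 = 6144.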

    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward number of partitions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation number of partitions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight number of partitions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy; the first element of this tuple is the forward
                memory cost, and the second element is the backward memory cost.
            memory_cost_forward_activation(float): Memory cost of the forward activation
                per device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of the backward activation
                per device with this specific strategy.
            memory_cost_backward_weight(float): Memory cost of the backward weight
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        # this operation will not change the shape of input
        numel_input = numel_output
        numel_weight = self.weight.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory_cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory_cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory_cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
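    # Illustrative example (not in the original source): with a float32 output of
    # 8 * 16 * 64 = 8192 elements (4 bytes each), sharding_size_forward = 4 and an
    # unsharded weight of 64 elements, the forward cost is
    # 8192 * 4 / 4 + 64 * 4 / 1 = 8192 + 256 = 8448 bytes per device.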

    def _generate_strategy_with_dim_partition(self, dim_partition):
        dim_partition_dict_for_input = dim_partition
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = dim_partition
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
        sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
        sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
                                                                                    sharding_size_backward_activation,
                                                                                    sharding_size_weight)

        total_mesh_dim_list = []
        for mesh_dim_list in dim_partition.values():
            total_mesh_dim_list.extend(mesh_dim_list)

        # this strategy does not need an all-reduce operation for the activation
        communication_cost_forward_activation = 0
        communication_cost_backward_activation = 0
        if len(total_mesh_dim_list) == 1:
            communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
                                                                                  total_mesh_dim_list[0])
        else:
            assert len(total_mesh_dim_list) == 2, 'temporarily we only support a 2D device mesh.'
            communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
                memory_cost_backward_weight, 0)
        communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)

    @ignore_sharding_exception
    def split_input_batch_single_mesh_dim(self, mesh_dim_0):
        batch_dimension_length = self.input_data.dim() - self.weight.dim()
        dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
        for dim_partition in dim_partition_list:
            self._generate_strategy_with_dim_partition(dim_partition)

    @ignore_sharding_exception
    def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
        batch_dimension_length = self.input_data.dim() - self.weight.dim()
        dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
        for dim_partition in dim_partition_list:
            self._generate_strategy_with_dim_partition(dim_partition)

    @ignore_sharding_exception
    def non_split(self):
        name = 'RR = RR x R'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        total_sharding_size = 1
        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = 1
        sharding_size_weight = 1
        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
                                                          sharding_size_weight)

        # this strategy does not need an all-reduce operation
        communication_cost = 0
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategy for a LayerNorm node, and record all strategies into the strategies_vector.

        Example:
            norm_handler = LayerNormHandler(node, strategies_vector,
                                            self.shape_consistency_manager)
            norm_handler.register_strategy()
            for strategy in norm_handler.strategies_vector:
                print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')

        Output:
            RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
            RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
            RR = RR x R, computation_cost: 262144, memory_cost: 1048576
            RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
        '''
        # SR = SR x R with a single mesh dim on the batch dimensions
        self.split_input_batch_single_mesh_dim(0)
        self.split_input_batch_single_mesh_dim(1)

        # SR = SR x R with both mesh dims on the batch dimensions
        self.split_input_batch_both_mesh_dim(0, 1)

        # RR = RR x R
        self.non_split()

        return self.strategies_vector
@ -1,149 +0,0 @@

from abc import ABC, abstractmethod
from typing import Dict, List

import torch
import torch.nn as nn
from torch.fx.node import Node

from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

from .._utils import generate_resharding_costs, generate_sharding_spec
from ..sharding_strategy import StrategiesVector

__all__ = ['OperatorHandler']


class OperatorHandler(ABC):
    '''
    The OperatorHandler is an abstract class used to generate every possible strategy for an operator node.

    Args:
        node (Node): the input node in the node argument list.
        device_mesh (DeviceMesh): a logical view of a physical mesh.
        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
        handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
    '''

    def __init__(self,
                 node: Node,
                 device_mesh: DeviceMesh,
                 strategies_vector: StrategiesVector,
                 handle_backward: bool = True):
        self.node = node
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        self.handle_backward = handle_backward

        # find the module and its parameters associated with this node
        # this can be used to compute the compute/communication/sharding cost
        if self.node.op == 'call_module':
            module = node.graph.owning_module.get_submodule(node.target)
            named_parameters = list(module.named_parameters(recurse=False))
            # convert named parameters from list to dict
            named_parameters = {k: v for k, v in named_parameters}
        elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
            module = None
            parameters = list(self.node.args)[1]
            if isinstance(parameters, Node):
                named_parameters = {'weight': parameters._meta_data}
            else:
                named_parameters = {}
        else:
            module = None
            named_parameters = None
        self.module = module
        self.module_named_parameters = named_parameters

    @abstractmethod
    def register_strategy(self) -> StrategiesVector:
        """
        Register all the possible sharding strategies of this node into the strategies_vector.
        """
        pass

    def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
                              dim_partition_dict_for_input):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            dim_partition_dict_for_output(Dict[int, List[int]]): the key is the dimension of the output to be sharded,
                and the value describes which logical axes will be sharded in that dimension.
            dim_partition_dict_for_weight(Dict[int, List[int]]): the key is the dimension of the weight to be sharded,
                and the value describes which logical axes will be sharded in that dimension.
            dim_partition_dict_for_input(Dict[int, List[int]]): the key is the dimension of the input to be sharded,
                and the value describes which logical axes will be sharded in that dimension.

        Return:
            memory_cost(Tuple[float]): (forward, backward) memory cost per device with this specific strategy
            activation_memory_cost(float): the memory cost of the activation per device with this specific strategy
            weight_memory_cost(float): the memory cost of the weight per device with this specific strategy
            input_grad_memory_cost(float): the memory cost of the input gradient per device with this specific strategy
        '''
        # compute the size of one element with the specific dtype
        dtype = self.input_data.dtype
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # compute the memory cost of the activation
        activation_numel = self.output_data.numel()
        output_mesh_dims = []
        for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
            output_mesh_dims.extend(mesh_dims)
        activation_sharding_size = 1
        for mesh_dim in output_mesh_dims:
            activation_sharding_size *= self.device_mesh.shape[mesh_dim]
        activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes

        # compute the memory cost of the weight
        weight_numel = self.weight.numel()
        weight_sharding_size = 1
        weight_mesh_dims = []
        for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
            weight_mesh_dims.extend(mesh_dims)
        for mesh_dim in weight_mesh_dims:
            weight_sharding_size *= self.device_mesh.shape[mesh_dim]
        weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes

        # compute the memory cost of the input grad
        input_grad_numel = self.input_data.numel()
        input_grad_sharding_size = 1
        input_grad_mesh_dims = []
        for sharding_dim, mesh_dims in dim_partition_dict_for_input.items():
            input_grad_mesh_dims.extend(mesh_dims)
        for mesh_dim in input_grad_mesh_dims:
            input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
        input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes

        memory_cost_forward = activation_memory_cost + weight_memory_cost
        memory_cost_backward = input_grad_memory_cost + weight_memory_cost

        return (memory_cost_forward,
                memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost

    def _generate_resharding_costs(self, sharding_specs):
        # The resharding_cost of the weight is counted due to weight-sharing cases.
        if hasattr(self.node._meta_data, 'dtype'):
            dtype = self.node._meta_data.dtype
        else:
            assert isinstance(self.node._meta_data,
                              tuple), 'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor are expected'
            dtype = self.node._meta_data[0].dtype

        nodes = self.predecessor_node
        return generate_resharding_costs(nodes=nodes,
                                         sharding_specs=sharding_specs,
                                         count_backward=self.handle_backward,
                                         dtype=dtype)

    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
        return generate_sharding_spec(input_=input_,
                                      device_mesh=self.device_mesh,
                                      dim_partition_dict=dim_partition_dict)

    @abstractmethod
    def _generate_compute_cost(self, *args, **kwargs):
        """
        Compute the FLOPs involved in the node.
        """
        pass
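# Illustrative sketch (not part of the original file): a minimal concrete
# handler built on the OperatorHandler API above. The class name and the cost
# model are hypothetical; real handlers such as LayerNormHandler follow the same
# pattern of generating sharding specs, costing them, and appending
# ShardingStrategy objects to self.strategies_vector.
from ..sharding_strategy import ShardingStrategy


class ExampleReplicatedHandler(OperatorHandler):

    def _generate_compute_cost(self, total_sharding_size):
        # hypothetical cost model: one operation per output element per device
        return self.node._meta_data.numel() / total_sharding_size

    def register_strategy(self) -> StrategiesVector:
        # register a single fully-replicated fallback strategy
        spec = self._generate_sharding_spec(self.node._meta_data, {})
        self.strategies_vector.append(
            ShardingStrategy('RR = RR x RR',
                             output_sharding_spec=spec,
                             compute_cost=self._generate_compute_cost(1),
                             resharding_costs=self._generate_resharding_costs([spec])))
        return self.strategies_vector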
@ -1,89 +0,0 @@

import math
import warnings
from copy import deepcopy

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from ..constants import INFINITY_COST
from .operator_handler import OperatorHandler


class ReshapeHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of reshape operators, such as torch.reshape, torch.flatten, etc.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, *args, **kwargs):
        return super()._generate_compute_cost(*args, **kwargs)

    @ignore_sharding_exception
    def register_strategy(self):
        # TODO: add strategies with more output sharding specs other than only fully replicated.
        input_node = self.strategies_vector.predecessor_nodes[0]
        # For a reshape function, to keep the computation correct we keep the sharding
        # spec of the input fully replicated. In addition, we keep the output in
        # replica status and let the successor node choose the way to reshard the
        # output node. Therefore, different strategies of the input node with the same
        # output sharding spec will generate the same strategy for the reshape function.
        sharding_spec_checklist = []
        for strategy in input_node.strategies_vector:
            # It looks a little bit confusing, but the input of the processing node
            # is the output of the input_node.
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
            if input_sharding_spec in sharding_spec_checklist:
                continue
            sharding_spec_checklist.append(input_sharding_spec)
            dim_partition_dict_for_output = {}
            if isinstance(self.output_data, tuple):
                dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
            try:
                if isinstance(self.output_data, tuple):
                    output_sharding_spec = []
                    for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
                        output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
                else:
                    output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
            except AssertionError as e:
                warnings.warn(f'{e}')
                continue
            name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
            # TODO: use meta_info_prop to profile the memory cost and compute cost
            compute_cost = 0
            # consider that node._meta_data may be a tuple
            memory_cost = 0

            # compute the communication cost: for a reshape op, the communication happens when casting the input sharding spec to fully replicated.
            dim_partition_dict_for_replicate_input = {}
            replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
                                                                         dim_partition_dict_for_replicate_input)
            # the shape consistency manager is a singleton class
            shape_consistency_manager = ShapeConsistencyManager()
            _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
                                                                                   replicate_input_sharding_spec)
            communication_cost = communication_cost["total"]

            # generate the resharding cost
            resharding_costs = self._generate_resharding_costs([input_sharding_spec])

            # to prevent resharding from happening, set the other resharding costs to inf.
            resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
            sharding_strategy = ShardingStrategy(name,
                                                 output_sharding_spec,
                                                 compute_cost=compute_cost,
                                                 communication_cost=communication_cost,
                                                 memory_cost=memory_cost,
                                                 resharding_costs=resharding_costs,
                                                 input_shardings=[input_sharding_spec])
            self.strategies_vector.append(sharding_strategy)
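# Illustrative note (not part of the original file): for an input strategy whose
# output spec is, say, S0R, the generated strategy is named 'S0R -> FULLY
# REPLICATED', its communication cost is the shape-consistency cost of casting
# S0R back to RR (essentially an all-gather), and other input strategies with
# the same spec are de-duplicated via sharding_spec_checklist.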
@ -1,46 +0,0 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List

from colossalai.device.device_mesh import DeviceMesh

__all__ = ['IntermediateStrategy', 'StrategyGenerator']


@dataclass
class IntermediateStrategy:
    """
    IntermediateStrategy contains the subset of meta information needed for a ShardingStrategy. It
    stores the essential information regarding the tensor sharding and leaves the other meta information to OperatorHandler.

    Args:
        name (str): name of the sharding strategy.
        dim_partition_dict (Dict[Dict]): stores the tensor-to-dim-partition-dict mapping.
        all_reduce_axis (List[int]): stores the axes which require an all-reduce operation.
    """
    name: str
    dim_partition_dict: Dict[str, Dict[int, List[int]]]
    all_reduce_axis: List[int] = None


class StrategyGenerator(ABC):
    """
    StrategyGenerator is used to generate one group of related sharding strategies.
    """

    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh

    @abstractmethod
    def generate(self) -> List[IntermediateStrategy]:
        """
        Generate all intermediate strategies of this group.
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Validate if the operands are of the desired shape.
        If True, this generator can be used for the current operation.
        """
        pass
@ -1,88 +0,0 @@

import math
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .operator_handler import OperatorHandler

__all__ = ['UnaryElementwiseHandler']


class UnaryElementwiseHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of unary element-wise ops.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.node.op == 'call_module':
            target = self.node.target
            submod = self.node.graph.owning_module.get_submodule(target)
            submod_type = type(submod)
            if submod_type == torch.nn.Dropout:
                print(f'predecessor nodes of dropout node are {self.predecessor_node}')
        input_nodes_len = 0
        for check_node in self.predecessor_node:
            if isinstance(check_node._meta_data, torch.Tensor):
                input_nodes_len += 1
        assert input_nodes_len == 1, f'Temporarily, we only support single-input element-wise ops; node name is {self.node}, node args are {self.node.args}.'
        self.input_data = self.predecessor_node[0]._meta_data
        self.input_node = self.predecessor_node[0]
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, *args, **kwargs):
        return super()._generate_compute_cost(*args, **kwargs)

    @ignore_sharding_exception
    def register_strategy(self):
        # TODO: integrate element-wise functions and modules together
        # create sharding strategies for element-wise functions

        # For an element-wise function, we keep the sharding spec of the output node the same as
        # that of the input. Therefore, different strategies of the input node with the same
        # output sharding spec will generate the same strategy for the element-wise function.

        for index, strategy in enumerate(self.input_node.strategies_vector):
            # It looks a little bit confusing, but the input of the processing node
            # is the output of the input_node.
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'

            dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                continue
            # add the index into the name to pass the duplication check
            # we keep the same strategies under different names for node merging; this does not increase the search space,
            # because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
            # TODO: use meta_info_prop to profile the memory cost and compute cost
            compute_cost = self.output_data.numel()
            memory_cost = 0

            resharding_costs = self._generate_resharding_costs([input_sharding_spec])

            # to prevent resharding from happening, set the other resharding costs to inf.
            resharding_costs[self.input_node] = [
                0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node]
            ]
            sharding_strategy = ShardingStrategy(name,
                                                 output_sharding_spec,
                                                 compute_cost=compute_cost,
                                                 memory_cost=memory_cost,
                                                 resharding_costs=resharding_costs,
                                                 input_shardings=[input_sharding_spec])
            self.strategies_vector.append(sharding_strategy)
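# Illustrative note (not part of the original file): the index suffix means that
# two input strategies which both produce, say, an S0R output yield strategies
# named 'S0R -> S0R_0' and 'S0R -> S0R_1' — identical in content but unique by
# name, so both survive the duplication check and can be merged in the solver.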
@ -1,188 +0,0 @@

import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
    enumerate_all_possible_1d_sharding,
    enumerate_all_possible_2d_sharding,
    ignore_sharding_exception,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .operator_handler import OperatorHandler

__all__ = ['WhereHandler']


class WhereHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of torch.where.
    """

    def __init__(self, *args, **kwargs):
        # TODO: x or y could be a scalar
        super().__init__(*args, **kwargs)
        assert len(self.predecessor_node) == 3
        self.condition_data = self.predecessor_node[0]._meta_data
        self.x_data = self.predecessor_node[1]._meta_data
        self.y_data = self.predecessor_node[2]._meta_data
        self.condition = self.predecessor_node[0]
        self.x = self.predecessor_node[1]
        self.y = self.predecessor_node[2]
        self.output_data = self.node._meta_data

    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
        shape = list(input_.shape)

        # pad the shape to the same length as output_data
        while len(shape) < self.output_data.dim():
            shape.insert(0, 1)
        shape = torch.Size(shape)

        # if the sharding happens on a size-one dimension, we should record it as R.
        processed_dim_partition_dict = deepcopy(dim_partition_dict)
        for dim_index, _ in dim_partition_dict.items():
            if shape[dim_index] == 1:
                processed_dim_partition_dict.pop(dim_index)
        for dim_index, sharding_index_list in processed_dim_partition_dict.items():
            sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
            sharding_size = reduce(operator.mul, sharding_list, 1)
            assert shape[
                dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of the tensor into {sharding_size} partitions.'
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=shape,
                                     dim_partition_dict=processed_dim_partition_dict)

        return sharding_spec

    def _generate_compute_cost(self, total_sharding_size):
        # NOTE: kept only to satisfy the abstract interface; it references lhs_data/rhs_data,
        # which WhereHandler does not define, and _register_strategy computes its cost directly instead.
        lhs_matrix_shape = self.lhs_data.shape[-2:]
        rhs_matrix_shape = self.rhs_data.shape[-2:]
        batch_dimensions_shape = self.output_data.shape[:-2]
        batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
        compute_cost = reduce(
            operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
        return compute_cost

    def _generate_resharding_costs(self, sharding_specs):
        # The resharding_cost of the weight is counted due to weight-sharing cases.
        dtype = self.node._meta_data.dtype
        nodes = self.predecessor_node
        resharding_costs = {}
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # the shape consistency manager is a singleton class
        shape_consistency_manager = ShapeConsistencyManager()

        for input_node, input_spec in zip(nodes, sharding_specs):
            resharding_costs[input_node] = []
            for strategy in input_node.strategies_vector:
                input_sharding_spec = strategy.output_sharding_spec
                assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
                # if the input shape is shorter than the target input, we pad the input to the same length as the target.
                # Then, we use the padded input sharding spec to compute the resharding cost.
                if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
                    new_entire_shape = list(input_sharding_spec.entire_shape)
                    while len(new_entire_shape) < len(input_spec.entire_shape):
                        new_entire_shape.insert(0, 1)
                    new_entire_shape = torch.Size(new_entire_shape)
                    new_device_mesh = input_sharding_spec.device_mesh
                    new_dim_partition_dict = input_sharding_spec.dim_partition_dict
                    input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
                                                       entire_shape=new_entire_shape,
                                                       dim_partition_dict=new_dim_partition_dict)

                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)
                total_resharding_cost = total_resharding_cost['total']
                # we need to multiply by the element size of the dtype to get the correct communication cost
                resharding_cost = total_resharding_cost * size_per_elem_bytes
                resharding_costs[input_node].append(resharding_cost)

        return resharding_costs

    def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):

        sharding_spec_list = []
        check_duplicated_list = []
        for output_dim_partition_dict in dim_partition_list:
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                break
            sharding_seq = output_sharding_spec.sharding_sequence
            if sharding_seq not in check_duplicated_list:
                check_duplicated_list.append(sharding_seq)
                sharding_spec_list.append(output_sharding_spec)

        return sharding_spec_list

    def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        # use mesh_dim_0 and mesh_dim_1 instead of the constants 0 and 1 here for N-D device mesh scalability.

        output_dim_partition_list = []
        dim_size = self.output_data.dim()
        # enumerate all the 2D sharding cases
        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_2d)

        # enumerate all the 1D sharding cases
        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_1)

        # add an empty dict for the fully replicated case
        output_dim_partition_list.append({})
        output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)

        return output_sharding_spec_list

    @ignore_sharding_exception
    def _register_strategy(self, output_sharding_spec):
        dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
        sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
        sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input)
        sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input)

        name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
        dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict

        # generate the resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs(
            [sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y])

        # compute the computation cost of this strategy
        sharding_dims = []
        for mesh_dims in dim_partition_dict_for_output.values():
            for mesh_dim in mesh_dims:
                sharding_dims.append(self.device_mesh.shape[mesh_dim])
        sharding_size = reduce(operator.mul, sharding_dims, 1)
        memory_cost = self.output_data.numel() / sharding_size
        compute_cost = memory_cost
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=output_sharding_spec,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_condition, sharding_spec_for_x,
                                                                sharding_spec_for_y))

        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        MESH_DIM_LIST = [0, 1]
        output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
        for output_sharding_spec in output_sharding_specs:
            self._register_strategy(output_sharding_spec)
        return self.strategies_vector
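# Illustrative example (not part of the original file): if output_data has shape
# (4, 8, 16) and the condition tensor has shape (16,), _generate_sharding_spec
# pads the condition shape to (1, 1, 16); a requested shard on one of the padded
# size-one dimensions is then dropped from the dim partition dict, i.e. recorded
# as R.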
@ -1,11 +0,0 @@

from dataclasses import dataclass

__all__ = ['SolverOptions']


@dataclass
class SolverOptions:
    """
    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
    """
    fast: bool = False
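# Illustrative usage (not part of the original file):
#     solver_options = SolverOptions(fast=True)
#     constructor = StrategiesConstructor(graph, device_mesh, solver_options)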
@ -1,91 +0,0 @@

import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from functools import reduce
from typing import Any, Dict, List, Tuple, Union

import torch
from torch.fx.node import Node

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import *

__all__ = ['ShardingStrategy', 'StrategiesVector']


@dataclass
class ShardingStrategy:
    '''
    ShardingStrategy is a structure containing the sharding strategies of the inputs and output of this node,
    along with the cost information used by the solver.

    Argument:
        name(str): expresses the sharding strategy as a string, such as 'S0S1 = S0R x RS1'.
        output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
        compute_cost(float): computation cost to complete this strategy. (default to 0)
        communication_cost(float): communication cost to complete this strategy. (default to 0)
        memory_cost(float): memory cost of the output node using this strategy. (default to 0)
        resharding_costs(Dict[Node, List[float]]): resharding_costs[node][j] is the cost of transforming that
            argument of this node, under the j-th strategy in its strategies_vector, to the sharding spec
            wanted by this strategy. (default to None)
        input_shardings(List[ShardingSpec]): the ShardingSpecs of the input nodes.
    '''

    name: str
    # TODO: the output of an fx node, such as torch.var_mean, could be a tuple, so we cannot simply assume it is a single tensor.
    output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
    compute_cost: float = 0.
    communication_cost: float = 0.
    memory_cost: float = 0.
    resharding_costs: Dict[Node, List[float]] = None
    # sometimes an input node could be a tuple of nodes, but most ops won't accept a tuple of nodes as input.
    # Therefore, we process them at the specific op (operator.getitem).
    input_shardings: List[ShardingSpec] = None


class StrategiesVector(list):
    '''
    Each node in an fx graph will have a corresponding StrategiesVector, which stores all the possible
    strategies of the node.

    Argument:
        node (Node): node for which the list of sharding strategies is generated.
    '''

    def __init__(self, node: Node):
        super().__init__()
        self.node = node
        # fetch its input and output nodes
        # TODO: placeholder input nodes
        self.predecessor_nodes = list(node._input_nodes.keys())
        if self.node.op == 'output':
            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
        self.successor_nodes = list(node.users.keys())

    def check_merge(self):
        merge_label = False
        if self.node.op == 'call_module':
            target = self.node.target
            root_module = self.node.graph.owning_module
            submod = root_module.get_submodule(target)
            submod_type = type(submod)
            # merge element-wise module nodes into their source nodes
            # we can merge element-wise ops because the output sharding spec is always the same as the input sharding spec.
            if submod_type in ELEMENTWISE_MODULE_OP:
                merge_label = True

        if self.node.op == 'call_function':
            # we can merge element-wise ops because the output sharding spec is always the same as the input sharding spec.
            if self.node.target in ELEMENTWISE_FUNC_OP:
                merge_label = True
            # we can merge a broadcast op if the rhs is a scalar, because it then falls back to the element-wise case.
            if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
                merge_label = True
            # we can merge reshape ops because the output sharding spec of a reshape op is always fully replicated.
            if self.node.target in RESHAPE_FUNC_OP:
                merge_label = True

        return merge_label
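# Illustrative example (not part of the original file): a strategy for a linear
# layer whose input is sharded along mesh axis 0 might look like
#
#     strategy = ShardingStrategy('S0S1 = S0R x RS1',
#                                 output_sharding_spec=output_spec,
#                                 compute_cost=2 * m * n * k / 4,
#                                 resharding_costs={input_node: [0.0, 1024.0]},
#                                 input_shardings=[input_spec, weight_spec])
#
# where output_spec/input_spec/weight_spec are hypothetical ShardingSpec objects
# and resharding_costs[input_node][j] prices moving the j-th strategy of
# input_node into input_spec.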
@ -1,469 +0,0 @@

import multiprocessing
import time
import warnings
from typing import Dict

import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node

from .constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor

try:
    import pulp
    from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except ImportError:
    warnings.warn('please install pulp')

__all__ = ['Solver']


class Solver:

    def __init__(self,
                 graph: Graph,
                 strategies_constructor: StrategiesConstructor,
                 cost_graph: CostGraph,
                 graph_analyser: GraphAnalyser,
                 memory_budget: float = -1.0,
                 solution_numbers: int = 1,
                 memory_increasing_coefficient: float = 1.3):
        '''
        The Solver class integrates the information provided by the components and uses an ILP solver to find a possibly optimal strategy combination for the target computation graph.

        Argument:
            graph: The computation graph to be optimized.
            strategies_constructor: It provides all the possible strategies for each node in the computation graph.
            cost_graph: A graph data structure used to simplify the edge cost graph.
            graph_analyser: graph_analyser analyses the graph to obtain variable liveness information, which is used to generate the memory constraints.
            memory_budget: Memory constraint for the solution.
            solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
            memory_increasing_coefficient: If solution_numbers is larger than one, this coefficient is used to generate new memory budgets.
        '''
        self.graph = graph
        self.strategies_constructor = strategies_constructor
        self.cost_graph = cost_graph
        self.graph_analyser = graph_analyser
        self.leaf_strategies = self.strategies_constructor.leaf_strategies
        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
        self.strategy_map = self.strategies_constructor.strategy_map
        self.memory_budget = memory_budget
        self.solution_numbers = solution_numbers
        if self.solution_numbers > 1:
            self.memory_increasing_coefficient = memory_increasing_coefficient
        else:
            self.memory_increasing_coefficient = 1
        self.liveness_list = self.graph_analyser.liveness_analysis()
        self.node_index_dict = self._generate_node_index_dict()
        # The last solution vector of auto sharding.
        self.last_s_val = None
        # The last objective value of the best ILP solution.
        self.last_objective = None

    def _recover_merged_node_strategy(self):
        '''
        During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOps, were merged into the previous node.
        Therefore, the strategy indices of those nodes were copied from the previous node. This method recovers the strategy index of those
        merged nodes.
        '''
        for node_index, node in enumerate(self.nodes):
            if node.strategies_vector.check_merge():
                # the merged node has only one input, and its strategies follow the input sharding strategy
                input_strategies_vector = node.args[0].strategies_vector
                input_best_strategy_index = self.last_s_val[node_index - 1]
                input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
                for strategy_index, strategy in enumerate(node.strategies_vector):
                    if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
                        self.last_s_val[node_index] = strategy_index
                        break

    def _generate_node_index_dict(self) -> Dict[Node, int]:
        node_index_dict = {}
        for index, strategies_vector in enumerate(self.leaf_strategies):
            node_index_dict[strategies_vector.node] = index
        return node_index_dict

    def _prepare_data_for_solver(self):
        '''
        Extract the information needed by the solver from the components.
        '''
        node_nums = len(self.leaf_strategies)
        memory_budget = self.memory_budget

        # prepare strategies_len
        strategies_len = []
        for node in self.nodes:
            strategies_len.append(self.cost_graph.node_lens[node])
        strategies_len = np.array(strategies_len)

        # prepare following_nodes
        following_nodes = self.cost_graph.following_dict
        index_following_nodes = {}
        for src, target in following_nodes.items():
            src_index = self.node_index_dict[src]
            target_index = self.node_index_dict[target]
            index_following_nodes[src_index] = target_index
        following_nodes = index_following_nodes
        for index in range(node_nums):
            if index not in following_nodes:
                following_nodes[index] = -1

        # prepare edge_pairs and resharding costs
        edge_pairs = []
        resharding_costs = []
        for pairs, edge_cost in self.cost_graph.edge_costs.items():
            src_node = pairs[0]
            dst_node = pairs[1]
            src_node_index = self.node_index_dict[src_node]
            dst_node_index = self.node_index_dict[dst_node]
            edge_pairs.append(src_node_index)
            edge_pairs.append(dst_node_index)

            for i in range(strategies_len[src_node_index]):
                for j in range(strategies_len[dst_node_index]):
                    resharding_costs.append(edge_cost[(i, j)])
        edge_pairs = np.array(edge_pairs)
        resharding_costs = np.array(resharding_costs)

        # prepare liveness_set
        liveness_set = self.liveness_list

        # omit alias_set for now
        alias_set = None
        alias_convert_costs = None

        # prepare compute_costs, communication_costs and memory_costs
        compute_costs = []
        communication_costs = []
        memory_costs = []
        extra_node_costs = self.cost_graph.extra_node_costs
        for strategies_vector in self.leaf_strategies:
            node = strategies_vector.node
            for index, strategy in enumerate(strategies_vector):
                compute_costs.append(strategy.compute_cost)
                # a node in extra_node_costs carries some extra communication
                # cost from node merging, so we need to add that extra cost
                # into the communication cost
                if node in extra_node_costs:
                    origin_communication_cost = strategy.communication_cost
                    extra_node_cost = extra_node_costs[node][index]
                    communication_cost = origin_communication_cost + extra_node_cost
                    communication_costs.append(communication_cost)
                else:
                    communication_costs.append(strategy.communication_cost)
                # temporarily we just consider the forward memory cost
                memory_cost = strategy.memory_cost
                if isinstance(memory_cost, tuple):
                    memory_costs.append(memory_cost[0])
                else:
                    memory_costs.append(memory_cost)
        compute_costs = np.array(compute_costs)
        communication_costs = np.array(communication_costs)
        memory_costs = np.array(memory_costs)

        # omit the initial value for nodes
        s_init_np = None

        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
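    # Illustrative example (not in the original source): with two nodes having
    # strategies_len = [2, 3] and a single edge (0, 1), edge_pairs is the flat
    # array [0, 1] and resharding_costs holds the 2 x 3 = 6 edge costs laid out
    # row-major: [c(0,0), c(0,1), c(0,2), c(1,0), c(1,1), c(1,2)]. The solver
    # unpacks them back per edge using strategies_len.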

    def _call_solver_serialized_args(self,
                                     node_nums,
                                     memory_budget,
                                     strategies_len,
                                     following_nodes,
                                     edge_pairs,
                                     alias_set,
                                     liveness_set,
                                     compute_costs,
                                     communication_costs,
                                     memory_costs,
                                     resharding_costs,
                                     alias_convert_costs,
                                     s_init_np=None):
        """
        Call the solver with serialized arguments.
        """

        tic = time.time()

        for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
            assert isinstance(x, np.ndarray)
        assert len(strategies_len) == node_nums, "strategies_len"

        def get_non_zero_index(binary_vector):
            """
            Get the index of the non-zero item in a vector.
            """
            ct = 0
            ret = None
            for i, elem in enumerate(binary_vector):
                if pulp.value(elem):
                    ret = i
                    ct += 1

            assert ct == 1
            return ret

        # 0. Unpack the flattened numpy arrays
        s_follow = following_nodes

        E = edge_pairs.reshape((-1, 2))    # noqa
        r = []
        pt = 0
        edge_set = set()
        for (i, j) in E:
            prod_length = strategies_len[i] * strategies_len[j]

            if (i, j) in edge_set:
                raise ValueError(f"Duplicated edges: {(i, j)}")

            edge_set.add((i, j))
            r.append(resharding_costs[pt:pt + prod_length])
            pt += prod_length
        assert pt == len(resharding_costs)

        ######################
        # omit alias set now #
        ######################

        # A = alias_set.reshape((-1, 2))  # noqa
        # for (i, j) in A:
        #     prod_length = strategies_len[i] * strategies_len[j]
        #     v.append(alias_convert_costs[pt:pt + prod_length])
        #     pt += prod_length
        # assert pt == len(alias_convert_costs)

        # L = []  # noqa
        # pt = node_nums
        # for i in range(node_nums):
        #     length = liveness_set[i]
        #     L.append(liveness_set[pt:pt + length])
        #     pt += length
        # assert pt == len(liveness_set)
        v = []
        pt = 0

        c = []
        d = []
        m = []
        pt = 0
        for i in range(node_nums):
            length = strategies_len[i]
            c.append(compute_costs[pt:pt + length])
            d.append(communication_costs[pt:pt + length])
            m.append(memory_costs[pt:pt + length])
            pt += length
        assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
        assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
        assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"

        # 1. Create variables

        #############################
        # create variables for node #
        #############################
        s = []
        num_nodes = 0
        reverse_follow_backpatch = []
        for i in range(node_nums):
            if s_follow[i] < 0:
                if strategies_len[i] == 1:
                    s.append([1])
                else:
                    num_nodes += 1
                    s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
            else:
                if s_follow[i] < len(s):
                    s.append(s[s_follow[i]])
                else:
                    s.append(None)
                    reverse_follow_backpatch.append(i)

        for i in reverse_follow_backpatch:
            s[i] = s[s_follow[i]]

        #############################
        # create variables for edge #
        #############################
        e = []
        num_edges = 0
        for (idx, (i, j)) in enumerate(E):
            if len(s[i]) == 1:
                e.append(s[j])
            elif len(s[j]) == 1:
                e.append(s[i])
            else:
                num_edges += 1
                e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
            assert len(e[idx]) == len(r[idx])
        for element in s:
            assert len(element) > 0
        # 2. Set the initial value
        #######################################
        # set an initial value for warm start #
        #######################################
        if s_init_np is not None:
            s_init = s_init_np.reshape((-1, 3))
            for (idx, value, fix) in s_init:
                for i in range(len(s[idx])):
                    s[idx][i].setInitialValue(i == value)
                    if fix:
                        s[idx][i].fixValue()

        # 3. Objective
        prob = LpProblem("myProblem", LpMinimize)
        ####################################################################
        # computing the node cost (computation cost and communication cost) #
        ####################################################################
        obj = 0
        for i in range(node_nums):
            assert len(s[i]) == len(c[i])
            assert len(s[i]) == len(d[i])

            obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])

        #############################################
        # computing the edge cost (resharding cost) #
        #############################################
        for i in range(len(E)):
            assert len(e[i]) == len(r[i])
            obj += lpDot(e[i], r[i])

        prob += obj
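        # In math form (illustrative summary, not in the original source), with
        # s_i the one-hot strategy selector of node i and e_ij the one-hot
        # selector of edge (i, j), the objective built above is
        #
        #     minimize  sum_i <s_i, c_i + d_i>  +  sum_{(i,j) in E} <e_ij, r_ij>
        #
        # i.e. per-node compute plus communication cost, plus per-edge
        # resharding cost.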
|
||||
# 4. Constraints
|
||||
# (a). specified by `cat="Binary"`
|
||||
|
||||
# (b)
|
||||
#################################################
|
||||
# make sure each node only choose one strategy #
|
||||
#################################################
|
||||
for i in range(node_nums):
|
||||
if s_follow[i] < 0:
|
||||
prob += lpSum(s[i]) == 1
|
||||
|
||||
# (c)
|
||||
#################################################
|
||||
# compute memory consumption with liveness set #
|
||||
#################################################
|
||||
if memory_budget > 0:
|
||||
for liveness_stage in liveness_set:
|
||||
mem = 0
|
||||
for live_variable in liveness_stage.unique_live_vars:
|
||||
node_index = self.node_index_dict[live_variable.node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
prob += mem <= memory_budget
|
||||
|
||||
# (d). specified by `cat="Binary"`
|
||||
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if strategies_len[i] == 1 or strategies_len[j] == 1:
|
||||
continue
|
||||
|
||||
# (e)
|
||||
prob += lpSum(e[idx]) == 1
|
||||
|
||||
# (f)
|
||||
for row in range(len(s[i])):
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
|
||||
|
||||
# (g)
|
||||
for col in range(len(s[j])):
|
||||
R = len(s[i]) # noqa
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
|
||||
|
||||
# (h)
|
||||
######################
|
||||
# omit alias set now #
|
||||
######################
|
||||
|
||||
# alias_set = set()
|
||||
# for (idx, (i, j)) in enumerate(A):
|
||||
# R = len(s[i]) # noqa
|
||||
# C = len(s[j]) # noqa
|
||||
# if (i, j) in alias_set:
|
||||
# raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||
|
||||
# alias_set.add((i, j))
|
||||
# alias_set.add((j, i))
|
||||
|
||||
# for row in range(len(s[i])):
|
||||
# for col in range(len(s[j])):
|
||||
# if v[idx][row * C + col] > 0.5:
|
||||
# prob += s[i][row] + s[j][col] <= 1
|
||||
|
||||
verbose = True
|
||||
|
||||
msg = verbose
|
||||
time_limit = 600
|
||||
assert "COIN_CMD" in pulp.listSolvers(
|
||||
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
|
||||
|
||||
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
|
||||
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
|
||||
prob.solve(solver)
|
||||
|
||||
status = prob.status
|
||||
objective = pulp.value(prob.objective)
|
||||
objective = float(objective) if objective is not None else -1.0
|
||||
if verbose:
|
||||
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
|
||||
f"Time: {time.time() - tic}")
|
||||
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
|
||||
|
||||
if prob.status in [pulp.LpStatusInfeasible]:
|
||||
raise RuntimeError("Cannot run the function under the given memory budget. "
|
||||
"Please increase the memory budget.")
|
||||
|
||||
# Get and check results
|
||||
s_val = np.full((node_nums,), -1, dtype=np.int32)
|
||||
for i in range(node_nums):
|
||||
s_val[i] = get_non_zero_index(s[i])
|
||||
|
||||
e_val = np.full((len(E),), -1, dtype=np.int32)
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
e_val[idx] = get_non_zero_index(e[idx])
|
||||
i_spec_index = e_val[idx] // len(s[j])
|
||||
j_spec_index = e_val[idx] % len(s[j])
|
||||
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
|
||||
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
|
||||
if verbose and r[idx][e_val[idx]] > 0:
|
||||
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
||||
|
||||
self.last_s_val = list(s_val)
|
||||
self._recover_merged_node_strategy()
|
||||
self.last_objective = objective
|
||||
|
||||
if objective > INFINITY_COST:
|
||||
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
||||
|
||||
return self.last_s_val, e_val, self.last_objective, status
|
||||
|
||||
def call_solver_serialized_args(self):
|
||||
"""
|
||||
Call the solver with serialized arguments and handle python errors. Additionally,
|
||||
we could give a serious of solutions with different memory budget.
|
||||
"""
|
||||
if self.solution_numbers == 1:
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
|
||||
return ret
|
||||
|
||||
origin_memory_budget = self.memory_budget
|
||||
memory_budget_list = [
|
||||
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
|
||||
]
|
||||
ret_list = []
|
||||
for memory_budget in memory_budget_list:
|
||||
self.memory_budget = memory_budget
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
ret_list.append(ret)
|
||||
|
||||
return ret_list
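
The formulation above is the classic one-hot ILP encoding: each node i carries a binary vector s[i] that selects exactly one strategy, each edge (i, j) carries a flattened binary matrix e[idx] that selects exactly one strategy pair, and constraints (f)/(g) force the selected pair to agree with the row and column choices. The following is a minimal self-contained sketch of that encoding with pulp, using hypothetical two-node toy costs rather than the solver's real data:

    # Hedged sketch of the one-hot ILP encoding used above (toy costs, assumed data).
    import pulp
    from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

    c = [[1.0, 4.0], [2.0, 1.0]]      # node costs c[i][k]: 2 nodes, 2 strategies each
    r = [0.0, 1000.0, 1000.0, 0.0]    # edge resharding costs, row-major over (row, col)

    prob = LpProblem("toy_auto_sharding", LpMinimize)
    s = [[LpVariable(f"s_{i}_{k}", cat="Binary") for k in range(2)] for i in range(2)]
    e = [LpVariable(f"e_{k}", cat="Binary") for k in range(4)]

    # objective: node costs plus edge (resharding) costs
    prob += lpDot(s[0], c[0]) + lpDot(s[1], c[1]) + lpDot(e, r)

    prob += lpSum(s[0]) == 1    # (b) node 0 picks exactly one strategy
    prob += lpSum(s[1]) == 1    # (b) node 1 picks exactly one strategy
    prob += lpSum(e) == 1       # (e) the edge picks exactly one strategy pair

    C = 2
    for row in range(2):        # (f) the pair's row agrees with node 0's choice
        prob += lpSum(e[row * C + col] for col in range(C)) <= s[0][row]
    for col in range(2):        # (g) the pair's column agrees with node 1's choice
        prob += lpSum(e[row * C + col] for row in range(2)) <= s[1][col]

    prob.solve(pulp.PULP_CBC_CMD(msg=False))
    print([int(pulp.value(v)) for v in s[0]], [int(pulp.value(v)) for v in s[1]])
    # expected: [1, 0] [1, 0] -- the resharding penalty keeps node 1 on strategy 0
    # even though its strategy 1 is cheaper locally, mirroring how the real solver
    # trades node cost against edge cost.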
@ -1,426 +0,0 @@
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List

import torch
from torch.fx import Graph, Node

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from ._utils import generate_resharding_costs, generate_sharding_spec
from .constants import *
from .op_handler import *
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector

__all__ = ['StrategiesConstructor']


class StrategiesConstructor:
    """
    StrategiesConstructor is used to construct the parallelization plan for the model execution.

    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
        solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
    """

    def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
        self.graph = graph
        assert graph.owning_module is not None, 'The given graph is not associated with an owning_module'
        self.root_module = self.graph.owning_module
        self.nodes = list(graph.nodes)
        self.device_mesh = device_mesh
        self.leaf_strategies = []
        self.strategy_map = {}
        self.solver_options = solver_options

    def remove_duplicated_strategy(self, strategies_vector):
        '''
        The build_strategies_and_cost method may produce duplicated strategies.
        This method removes the duplicates based on the strategy name.
        '''
        name_checklist = []
        remove_list = []
        for strategy in strategies_vector:
            if strategy.name not in name_checklist:
                name_checklist.append(strategy.name)
            else:
                remove_list.append(strategy)

        for strategy in remove_list:
            strategies_vector.remove(strategy)

    def _is_bcast_matmul(self, node):
        is_bcast_matmul = False
        if node.target is torch.matmul and len(node.args) == 2:
            lhs_data = node.args[0]._meta_data
            rhs_data = node.args[1]._meta_data
            if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
                is_bcast_matmul = True
        return is_bcast_matmul

    def build_strategies_and_cost(self):
        for node in self.nodes:
            strategies_vector = StrategiesVector(node)
            input_nodes_len = 0
            for check_node in strategies_vector.predecessor_nodes:
                if isinstance(check_node._meta_data, torch.Tensor):
                    input_nodes_len += 1
            # input_nodes_len = len(strategies_vector.predecessor_nodes)
            # placeholder node
            if node.op == 'placeholder':
                # For placeholder nodes, if solver_options.fast is True, we simply put them in
                # fully replicated status, so the strategies of the following nodes are treated
                # equally, because the replicated status has no resharding cost to any other status.
                # At the same time, the search space is smaller than enumerating all the possible
                # sharding specs for the placeholder node. Otherwise, all the possible sharding
                # specs for the placeholder node will be enumerated.

                if self.solver_options.fast:
                    # create sharding strategy for placeholder
                    name = 'Replica Placeholder'
                    dim_partition_dict = {}
                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
                    # TODO: use meta_info_prop to profile memory cost
                    memory_cost = 0
                    sharding_strategy_placeholder = ShardingStrategy(name,
                                                                     output_sharding_spec,
                                                                     memory_cost=memory_cost)
                    strategies_vector.append(sharding_strategy_placeholder)

            # get_attr node
            if node.op == 'get_attr':
                # Same as placeholder nodes: if solver_options.fast is True, we simply put them in
                # fully replicated status, so the strategies of the following nodes are treated
                # equally, because the replicated status has no resharding cost to any other status.
                # At the same time, the search space is smaller than enumerating all the possible
                # sharding specs for the get_attr node. Otherwise, all the possible sharding specs
                # for the get_attr node will be enumerated.
                if self.solver_options.fast:
                    # create sharding strategy for get_attr
                    name = 'Replica Attribute'
                    dim_partition_dict = {}
                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
                    # TODO: use meta_info_prop to profile memory cost
                    memory_cost = 0
                    sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
                    strategies_vector.append(sharding_strategy_attribute)

            # call_module node
            if node.op == 'call_module':

                target = node.target
                submod = self.root_module.get_submodule(target)
                submod_type = type(submod)

                # conv module
                if submod_type in CONV_MODULE_OP:
                    # use ConvHandler to create sharding strategies for conv module node
                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
                    conv_handler.register_strategy()

                # linear module
                elif submod_type in LINEAR_MODULE_OP:
                    # use DotHandler to create sharding strategies for linear module node
                    dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
                    dot_handler.register_strategy()

                # element-wise module
                elif submod_type in ELEMENTWISE_MODULE_OP:
                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
                    unary_elementwise_handler.register_strategy()

                # BatchNormNd module
                elif submod_type in BATCHNORM_MODULE_OP:
                    # create sharding strategy for batch-norm module
                    norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
                    norm_handler.register_strategy()
                    # for strategy in norm_handler.strategies_vector:
                    #     print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
                    # assert False

                # MaxPool module
                elif submod_type in POOL_MODULE_OP:
                    # TODO: add sharding constraints on the image dimensions
                    # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on the H and W dimensions

                    # create sharding strategy, treating the pooling module like an element-wise op
                    assert input_nodes_len == 1, 'Temporarily, we only support single-input element-wise ops.'
                    input_node = strategies_vector.predecessor_nodes[0]
                    # For an element-wise module, we keep the sharding spec of the output node the same
                    # as that of the input. Therefore, different strategies of the input node with the
                    # same output sharding spec generate the same strategy for the element-wise module.
                    sharding_spec_checklist = []
                    for strategy in input_node.strategies_vector:
                        # It may look a little confusing: the input of the processing node
                        # is the output of the input_node.
                        input_sharding_spec = strategy.output_sharding_spec
                        assert isinstance(input_sharding_spec,
                                          ShardingSpec), 'The input node should NOT be a tuple of tensor.'
                        if input_sharding_spec in sharding_spec_checklist:
                            continue

                        sharding_spec_checklist.append(input_sharding_spec)
                        dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
                        output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)

                        name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'

                        # TODO: use meta_info_prop to profile memory cost and compute cost
                        compute_cost = node._meta_data.numel()
                        memory_cost = 0
                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                                     [input_sharding_spec])

                        sharding_strategy = ShardingStrategy(name,
                                                             output_sharding_spec,
                                                             compute_cost=compute_cost,
                                                             memory_cost=memory_cost,
                                                             resharding_costs=resharding_costs,
                                                             input_shardings=[input_sharding_spec])
                        strategies_vector.append(sharding_strategy)

                # embedding module
                elif submod_type in EMBEDDING_MODULE_OP:
                    embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
                    embedding_handler.register_strategy()

                # layernorm module
                elif submod_type in LAYERNORM_MODULE_OP:
                    layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
                    layernorm_handler.register_strategy()
                # other module
                else:
                    raise RuntimeError(f'{submod_type} module is NOT supported now.')

            # call_function node
            if node.op == 'call_function':
                target = node.target
                # conv function
                if target in CONV_FUNC_OP:
                    # use ConvHandler to create sharding strategies for conv node
                    # TODO: the operator_handler does NOT support function node processing now.
                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
                    conv_handler.register_strategy()

                # linear function
                elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
                    # use DotHandler to create sharding strategies for linear node
                    # TODO: the operator_handler does NOT support function node processing now.
                    linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
                    linear_handler.register_strategy()

                # where function
                elif target == torch.where:
                    if input_nodes_len == 1:
                        # both x and y are scalars
                        pass

                    elif input_nodes_len == 2:
                        # one of x and y is a scalar
                        pass

                    else:
                        # general case
                        where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
                        where_handler.register_strategy()

                # reshape function
                elif target in RESHAPE_FUNC_OP:
                    # use ReshapeHandler to create sharding strategies for reshape node
                    reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
                    reshape_handler.register_strategy()

                # element-wise function
                elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
                    unary_elementwise_handler.register_strategy()

                # bcast op
                elif target in BCAST_FUNC_OP:
                    if isinstance(node._meta_data, torch.Tensor):
                        bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
                        bcast_op_handler.register_strategy()

                # torch.var_mean
                elif target == torch.var_mean:
                    dim = node.kwargs['dim']
                    input_tensor_node = strategies_vector.predecessor_nodes[0]
                    for strategy in input_tensor_node.strategies_vector:
                        input_sharding_spec = strategy.output_sharding_spec
                        assert isinstance(input_sharding_spec,
                                          ShardingSpec), 'The input node should NOT be a tuple of tensor.'
                        entire_shape_input = input_sharding_spec.entire_shape
                        dim_partition_dict_input = input_sharding_spec.dim_partition_dict
                        if dim in dim_partition_dict_input:
                            # We need to put the reduction dimension into replicate status
                            dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
                            dim_partition_dict_for_input.pop(dim)
                            new_input_sharding_spec = ShardingSpec(self.device_mesh,
                                                                   entire_shape_input,
                                                                   dim_partition_dict=dim_partition_dict_for_input)
                            entire_shape_output = deepcopy(entire_shape_input)
                            entire_shape_output.pop(dim)
                            dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
                            output_sharding_spec = ShardingSpec(self.device_mesh,
                                                                entire_shape_output,
                                                                dim_partition_dict=dim_partition_dict_for_output)
                            name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
                            # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                            compute_cost = 0
                            memory_cost = 0
                            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                                         [new_input_sharding_spec])
                            sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                                 compute_cost=compute_cost,
                                                                 memory_cost=memory_cost,
                                                                 resharding_costs=resharding_costs,
                                                                 input_shardings=[new_input_sharding_spec])

                        else:
                            entire_shape_output = deepcopy(entire_shape_input)
                            entire_shape_output.pop(dim)
                            dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
                            output_sharding_spec = ShardingSpec(self.device_mesh,
                                                                entire_shape_output,
                                                                dim_partition_dict=dim_partition_dict_for_output)
                            name = f'{input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
                            # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                            compute_cost = 0
                            memory_cost = 0
                            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                                         [input_sharding_spec])
                            sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
                                                                 compute_cost=compute_cost,
                                                                 memory_cost=memory_cost,
                                                                 resharding_costs=resharding_costs,
                                                                 input_shardings=[input_sharding_spec])

                        strategies_vector.append(sharding_strategy)

                # operator.getitem
                elif target == operator.getitem:
                    index = node.args[1]
                    input_tensor_node = strategies_vector.predecessor_nodes[0]
                    for strategy in input_tensor_node.strategies_vector:
                        if isinstance(strategy.output_sharding_spec, ShardingSpec):
                            input_sharding_spec = strategy.output_sharding_spec
                        else:
                            input_sharding_spec = strategy.output_sharding_spec[index]
                        assert isinstance(input_sharding_spec, ShardingSpec), 'This assertion is used to debug.'
                        dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
                        entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
                        output_sharding_spec = ShardingSpec(self.device_mesh,
                                                            entire_shape_output,
                                                            dim_partition_dict=dim_partition_dict_for_output)
                        name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
                        # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
                        compute_cost = 0
                        memory_cost = 0
                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                                     [input_sharding_spec],
                                                                     index=index)
                        # to prevent resharding from happening, set the resharding cost to inf.
                        resharding_costs[input_tensor_node] = [
                            cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
                        ]
                        sharding_strategy = ShardingStrategy(name,
                                                             output_sharding_spec,
                                                             compute_cost=compute_cost,
                                                             memory_cost=memory_cost,
                                                             resharding_costs=resharding_costs,
                                                             input_shardings=[strategy.output_sharding_spec])
                        strategies_vector.append(sharding_strategy)

                # torch.arange function
                elif target == torch.arange:
                    name = 'FULLY REPLICATED ARANGE'
                    entire_shape_output = node._meta_data.shape
                    dim_partition_dict_for_output = {}
                    output_sharding_spec = ShardingSpec(self.device_mesh,
                                                        entire_shape_output,
                                                        dim_partition_dict=dim_partition_dict_for_output)
                    memory_cost = node._meta_data.numel()
                    sharding_strategy = ShardingStrategy(name,
                                                         output_sharding_spec,
                                                         compute_cost=0,
                                                         memory_cost=memory_cost)
                    strategies_vector.append(sharding_strategy)

                # op list to be processed to support gpt2
                elif target in (builtins.getattr, operator.le, torch.addmm):
                    pass
                # other function
                else:
                    raise RuntimeError(f'{target} function is NOT supported now.')

            # call_method node
            if node.op == 'call_method':
                method = getattr(node.args[0]._meta_data.__class__, node.target)
                if method in (torch.Tensor.size,):
                    pass
                elif method in ELEMENTWISE_METHOD_OP:
                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
                    unary_elementwise_handler.register_strategy()

                elif method in RESHAPE_METHOD_OP:
                    reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
                    reshape_handler.register_strategy()
                    # print(strategies_vector)
                    # if len(strategies_vector) == 0:
                    #     print(node)
                    #     assert False
                else:
                    raise RuntimeError(f'{method} method is NOT supported now.')

            # output node
            if node.op == 'output':
                if self.solver_options.fast:
                    # create sharding strategy for output
                    name = 'Replica Output'
                    input_nodes = strategies_vector.predecessor_nodes
                    input_sharding_specs = []
                    for input_node in input_nodes:
                        dim_partition_dict_for_input = {}
                        entire_shape = input_node._meta_data.shape
                        sharding_spec = ShardingSpec(self.device_mesh,
                                                     entire_shape,
                                                     dim_partition_dict=dim_partition_dict_for_input)
                        input_sharding_specs.append(sharding_spec)

                    dim_partition_dict = {}
                    output_sharding_spec = input_sharding_specs
                    # TODO: use meta_info_prop to profile memory cost
                    memory_cost = 0
                    resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
                                                                 input_sharding_specs)

                    # clear the resharding cost for the output node
                    # TODO: we may remove this in the final version
                    for prev_node, resharding_cost_list in resharding_costs.items():
                        resharding_costs[prev_node] = [0] * len(resharding_cost_list)

                    sharding_strategy_attribute = ShardingStrategy(name,
                                                                   output_sharding_spec,
                                                                   memory_cost=memory_cost,
                                                                   resharding_costs=resharding_costs,
                                                                   input_shardings=tuple(input_sharding_specs))
                    strategies_vector.append(sharding_strategy_attribute)

            self.remove_duplicated_strategy(strategies_vector)
            setattr(node, 'strategies_vector', strategies_vector)
            self.leaf_strategies.append(strategies_vector)
            self.strategy_map[node] = strategies_vector

        # remove nodes without strategies
        remove_list = []
        for strategies_vector in self.leaf_strategies:
            if len(strategies_vector) == 0:
                remove_list.append(strategies_vector.node)
        for node in remove_list:
            if node.strategies_vector in self.leaf_strategies:
                self.leaf_strategies.remove(node.strategies_vector)
            if node in self.strategy_map:
                self.strategy_map.pop(node)
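
As a rough usage sketch of the class above (hedged: the tracer import path and the DeviceMesh arguments are assumptions inferred from how they are used elsewhere in this diff, not verified entry points):

    # Hedged usage sketch for StrategiesConstructor (assumed imports and arguments).
    import torch
    import torch.nn as nn

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.fx.tracer.tracer import ColoTracer    # assumed path

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    # a 2x2 logical mesh over 4 devices; init_process_group=False keeps it offline
    device_mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), init_process_group=False)

    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args={'input': torch.rand(4, 8).to('meta')})

    solver_options = SolverOptions(fast=True)
    constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    constructor.build_strategies_and_cost()

    # every surviving node now carries a strategies_vector
    for node, vector in constructor.strategy_map.items():
        print(node.name, len(vector))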
@ -8,14 +8,9 @@ from torch.fx.graph import Graph

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import (
    CostGraph,
    GraphAnalyser,
    Solver,
    SolverOptions,
    StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
    pass


def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
                               shard_option: str):
    '''
    This method is used to build the strategy_constructor for the given graph.
    After this method, each node in the graph will have a strategies_vector which
    is constructed by the related node handler.
    '''
    solver_options = SolverOptions()
    if solver_preference == 'standard':
        solver_preference = SolverPerference.STANDARD
    elif solver_preference == 'tp':
        solver_preference = SolverPerference.TP
    elif solver_preference == 'dp':
        solver_preference = SolverPerference.DP
    else:
        raise ValueError(f'Invalid solver_preference: {solver_preference}')

    if dataloader_option == 'replicated':
        dataloader_option = DataloaderOption.REPLICATED
    elif dataloader_option == 'distributed':
        dataloader_option = DataloaderOption.DISTRIBUTED
    else:
        raise ValueError(f'Invalid dataloader_option: {dataloader_option}')

    if shard_option == 'standard':
        shard_option = ShardOption.STANDARD
    elif shard_option == 'shard':
        shard_option = ShardOption.SHARD
    elif shard_option == 'shard_last_axis':
        shard_option = ShardOption.SHARD_LAST_AXIS
    elif shard_option == 'full_shard':
        shard_option = ShardOption.FULL_SHARD
    else:
        raise ValueError(f'Invalid shard_option: {shard_option}')

    solver_options = SolverOptions(solver_perference=solver_preference,
                                   dataloader_option=dataloader_option,
                                   shard_option=shard_option)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
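
A hedged sketch of calling the wrapper above (the traced `graph` and `device_mesh` are assumed to already exist, as in `initialize_model` further down):

    # Hedged sketch: the string options are validated and mapped onto the enums above.
    strategies_constructor = build_strategy_constructor(graph,
                                                        device_mesh,
                                                        solver_preference='tp',
                                                        dataloader_option='distributed',
                                                        shard_option='shard')
    # An unknown string fails fast:
    #   build_strategy_constructor(graph, device_mesh, 'fastest', 'replicated', 'standard')
    #   -> ValueError: Invalid solver_preference: fastest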
@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
                     device_mesh: DeviceMesh,
                     memory_budget: float = -1.0,
                     overlap: bool = False,
                     solver_preference: str = 'standard',
                     dataloader_option: str = 'replicated',
                     shard_option: str = 'standard',
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solution_path: str = None,
@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
            the memory budget will be infinity.
        overlap(optional): the overlap is used to specify whether to overlap gradient communication and
            backward computing.
        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
        save_solver_solution(optional): if save_solver_solution is True, the solution will be saved
            to the solution_path.
        load_solver_solution(optional): if load_solver_solution is True, the solution will be loaded
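
For reference, a hedged top-level call combining these options might look like the following (the return structure is not shown in this hunk, so it is treated as opaque; the DeviceMesh arguments and meta_args are assumptions):

    # Hedged sketch: driving initialize_model with the new options.
    physical_mesh_id = torch.arange(4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2), init_process_group=True)

    ret = initialize_model(model,
                           meta_args={'x': torch.rand(16, 8).to('meta')},
                           device_mesh=device_mesh,
                           memory_budget=-1.0,               # negative means unlimited
                           solver_preference='tp',
                           dataloader_option='distributed',
                           shard_option='shard',
                           save_solver_solution=True,
                           solution_path='./solution.pt')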
@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    strategies_constructor = build_strategy_constructor(graph, device_mesh)

    strategies_constructor = build_strategy_constructor(graph,
                                                        device_mesh,
                                                        solver_preference=solver_preference,
                                                        dataloader_option=dataloader_option,
                                                        shard_option=shard_option)
    if load_solver_solution:
        solution = torch.load(solution_path)
    else:
@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
                    alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                    logical_mesh_shape: Tuple[int] = None,
                    logical_mesh_id: torch.Tensor = None,
                    solver_preference: str = 'standard',
                    dataloader_option: str = 'replicated',
                    shard_option: str = 'standard',
                    save_solver_solution: bool = False,
                    load_solver_solution: bool = False,
                    solver_solution_path: str = None,
@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
            mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
            generated by the search_best_logical_mesh_shape function.
        logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
        save_solver_solution(optional): if save_solver_solution is True, the solution will be saved
            to the solution_path.
        load_solver_solution(optional): if load_solver_solution is True, the solution will be loaded
@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
    rst_to_unpack = initialize_model(model,
                                     meta_args,
                                     device_mesh,
                                     solver_preference=solver_preference,
                                     dataloader_option=dataloader_option,
                                     save_solver_solution=save_solver_solution,
                                     load_solver_solution=load_solver_solution,
                                     solution_path=solver_solution_path,
@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
@ -31,6 +30,6 @@ __all__ = [
    'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
    'TransposeHandler', 'SplitHandler'
    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
    'SplitHandler'
]
@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler):
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(
            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
            LinearProjectionStrategyGenerator(op_data_mapping,
                                              self.device_mesh,
                                              linear_projection_type='linear',
                                              solver_perference=self.solver_perference))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@ -5,7 +5,7 @@ import torch
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
    OperationDataType,
@ -32,19 +32,19 @@ class NodeHandler(ABC):
        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
    '''

    def __init__(
        self,
        node: Node,
        device_mesh: DeviceMesh,
        strategies_vector: StrategiesVector,
        shard_option: ShardOption = ShardOption.STANDARD,
    ) -> None:
    def __init__(self,
                 node: Node,
                 device_mesh: DeviceMesh,
                 strategies_vector: StrategiesVector,
                 shard_option: ShardOption = ShardOption.STANDARD,
                 solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
        self.node = node
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        self.shard_option = shard_option
        self.solver_perference = solver_perference

    def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
        """
@ -187,15 +187,24 @@ class NodeHandler(ABC):

        remove_strategy_list = []
        for strategy in self.strategies_vector:
            shard_level = 0
            shard_axis_list = []
            last_axis = len(self.device_mesh.mesh_shape) - 1
            for op_data, sharding_spec in strategy.sharding_specs.items():
                if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                    for dim, shard_axis in sharding_spec.dim_partition_dict.items():
                        shard_level += len(shard_axis)
                    for dim, shard_axes in sharding_spec.dim_partition_dict.items():
                        for shard_axis in shard_axes:
                            if shard_axis not in shard_axis_list:
                                shard_axis_list.append(shard_axis)

            shard_level = len(shard_axis_list)
            using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
            if self.shard_option == ShardOption.SHARD and shard_level == 0:
                remove_strategy_list.append(strategy)
            if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
                remove_strategy_list.append(strategy)
            if self.shard_option == ShardOption.SHARD_LAST_AXIS:
                if shard_level != 1 or not using_last_axis:
                    remove_strategy_list.append(strategy)

        for strategy in remove_strategy_list:
            self.strategies_vector.remove(strategy)
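
To make the filtering concrete, here is a hedged standalone sketch of the shard-level computation on a sample dim_partition_dict (toy data, independent of the handler class):

    # Hedged sketch: how shard_level falls out of a dim_partition_dict.
    # A dim_partition_dict maps tensor dims to the device-mesh axes sharding them,
    # e.g. {0: [0], 1: [1]} shards dim 0 over mesh axis 0 and dim 1 over mesh axis 1.
    mesh_shape = (2, 2)                       # assumed 2D device mesh
    dim_partition_dict = {0: [0], 1: [1]}

    shard_axis_list = []
    for dim, shard_axes in dim_partition_dict.items():
        for shard_axis in shard_axes:
            if shard_axis not in shard_axis_list:
                shard_axis_list.append(shard_axis)

    shard_level = len(shard_axis_list)        # 2 -> survives FULL_SHARD (needs > 1)
    last_axis = len(mesh_shape) - 1
    using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
    print(shard_level, using_last_axis)       # 2 True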
@ -1,17 +0,0 @@
from enum import Enum

__all__ = ['ShardOption']


class ShardOption(Enum):
    """
    This enum class defines the shard level required in node strategies.

    Notes:
        STANDARD: We do not add any extra shard requirements.
        SHARD: We require the node to be sharded using at least one device mesh axis.
        FULL_SHARD: We require the node to be sharded using all device mesh axes.
    """
    STANDARD = 0
    SHARD = 1
    FULL_SHARD = 2