Update metainfo patch branch (#2517)

* init

* rename and remove useless func

* basic chunk

* add evoformer

* align evoformer

* add meta

* basic chunk

* basic memory

* finish basic inference memory estimation

* finish memory estimation

* fix bug

* finish memory estimation

* add part of index tracer

* finish basic index tracer

* add doc string

* add doc str

* polish code

* polish code

* update active log

* polish code

* add possible region search

* finish region search loop

* finish chunk define

* support new op

* rename index tracer

* finish codegen on msa

* redesign index tracer, add source and change compute

* pass outer product mean

* code format

* code format

* work with outerproductmean and msa

* code style

* code style

* code style

* code style

* change threshold

* support check_index_duplicate

* support index duplicate and update loop

* support output

* update memory estimate

* optimise search

* fix layernorm

* move flow tracer

* refactor flow tracer

* format code

* refactor flow search

* code style

* adapt codegen to prepose node

* code style

* remove abandoned function

* remove flow tracer

* code style

* code style

* reorder nodes

* finish node reorder

* update run

* code style

* add chunk select class

* add chunk select

* code style

* add chunksize in emit, fix bug in reassign shape

* code style

* turn off print mem

* add evoformer openfold init

* init openfold

* add benchmark

* add print

* code style

* code style

* init openfold

* update openfold

* align openfold

* use max_mem to control strategy

* update source add

* add reorder in mem estimator

* improve reorder efficiency

* support ones_like, add prompt if fit mode search fails

* fix a bug in ones_like, don't gen chunk if dim size is 1

* fix bug again

* update min memory strategy, reduce mem usage by 30%

* last version of benchmark

* refactor structure

* restructure dir

* update test

* rename

* take apart chunk code gen

* close mem and code print

* code format

* rename ambiguous variable

* separate flow tracer

* separate input node dim search

* separate prepose_nodes

* separate non-chunk input

* separate reorder

* rename

* add reorder graph

* separate trace flow

* code style

* code style

* fix typo

* set benchmark

* rename test

* update codegen test

* Fix state_dict key missing issue of the ZeroDDP (#2363)

* Fix state_dict output for ZeroDDP duplicated parameters

* Rewrite state_dict based on get_static_torch_model

* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)

* update codegen test

* update codegen test

* add chunk search test

* code style

* add available

* [hotfix] fix gpt gemini example (#2404)

* [hotfix] fix gpt gemini example

* [example] add new assertions

* remove autochunk_available

* [workflow] added nightly release to pypi (#2403)

* add comments

* code style

* add doc for search chunk

* [doc] updated readme regarding pypi installation (#2406)

* add doc for search

* [doc] updated kernel-related optimisers' docstring (#2385)

* [doc] updated kernel-related optimisers' docstring

* polish doc

* rename trace_index to trace_indice

* rename function from index to indice

* rename

* rename in doc

* [polish] polish code for get_static_torch_model (#2405)

* [gemini] polish code

* [testing] remove code

* [gemini] make more robust

* rename

* rename

* remove useless function

* [workflow] added coverage test (#2399)

* [workflow] added coverage test

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* add doc for trace indice

* [docker] updated Dockerfile and release workflow (#2410)

* add doc

* update doc

* add available

* change imports

* add test in import

* [workflow] refactored the example check workflow (#2411)

* [workflow] refactored the example check workflow

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* Update parallel_context.py (#2408)

* [hotfix] add DISTPAN argument for benchmark (#2412)

* change the benchmark config file

* change config

* revert config file

* rename distpan to distplan

* [workflow] added precommit check for code consistency (#2401)

* [workflow] added precommit check for code consistency

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* adapt new fx

* [workflow] added translation for non-english comments (#2414)

* [setup] refactored setup.py for dependency graph (#2413)

* change import

* update doc

* [workflow] auto comment if precommit check fails (#2417)

* [hotfix] add norm clearing for the overflow step (#2416)

* [examples] adding tflops to PaLM (#2365)

* [workflow]auto comment with test coverage report (#2419)

* [workflow]auto comment with test coverage report

* polish code

* polish yaml

* [doc] added documentation for CI/CD (#2420)

* [doc] added documentation for CI/CD

* polish markdown

* polish markdown

* polish markdown

* [example] removed duplicated stable diffusion example (#2424)

* [zero] add inference mode and its unit test (#2418)

* [workflow] report test coverage even if below threshold (#2431)

* [example] improved the clarity of the example readme (#2427)

* [example] improved the clarity of the example readme

* polish workflow

* polish workflow

* polish workflow

* polish workflow

* polish workflow

* polish workflow

* [ddp] add is_ddp_ignored (#2434)

[ddp] rename to is_ddp_ignored

* [workflow] make test coverage report collapsable (#2436)

* [autoparallel] add shard option (#2423)

* [fx] allow native ckpt trace and codegen. (#2438)

* [cli] provided more details if colossalai run fail (#2442)

* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)

* [autoparallel] integrate device mesh initialization into autoparallelize

* add megatron solution

* update gpt autoparallel examples with latest api

* adapt beta value to fit the current computation cost

* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)

* [ddp] add is_ddp_ignored

[ddp] rename to is_ddp_ignored

* [zero] fix state_dict and load_state_dict

* fix bugs

* [zero] update unit test for ZeroDDP

* [example] updated the hybrid parallel tutorial (#2444)

* [example] updated the hybrid parallel tutorial

* polish code

* [zero] add warning for ignored parameters (#2446)

* [example] updated large-batch optimizer tutorial (#2448)

* [example] updated large-batch optimizer tutorial

* polish code

* polish code

* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)

* [workflow] fixed the on-merge condition check (#2452)

* [workflow] automated the compatibility test (#2453)

* [workflow] automated the compatibility test

* polish code

* [autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler

* polish

* [workflow] automated bdist wheel build (#2459)

* [workflow] automated bdist wheel build

* polish workflow

* polish readme

* polish readme

* Fix False warning in initialize.py (#2456)

* Update initialize.py

* pre-commit run check

* [examples] update autoparallel tutorial demo (#2449)

* [examples] update autoparallel tutorial demo

* add test_ci.sh

* polish

* add conda yaml

* [cli] fixed hostname mismatch error (#2465)

* [example] integrate autoparallel demo with CI (#2466)

* [example] integrate autoparallel demo with CI

* polish code

* polish code

* polish code

* polish code

* [zero] low level optim supports ProcessGroup (#2464)

* [example] update vit ci script (#2469)

* [example] update vit ci script

* [example] update requirements

* [example] update requirements

* [example] integrate seq-parallel tutorial with CI (#2463)

* [zero] polish low level optimizer (#2473)

* polish pp middleware (#2476)

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>

* [example] update gpt gemini example ci test (#2477)

* [zero] add unit test for low-level zero init (#2474)

* [workflow] fixed the skip condition of  example weekly check workflow (#2481)

* [example] stable diffusion add roadmap

* add dummy test_ci.sh

* [example] stable diffusion add roadmap (#2482)

* [CI] add test_ci.sh for palm, opt and gpt (#2475)

* polish code

* [example] titans for gpt

* polish readme

* remove license

* polish code

* update readme

* [example] titans for gpt (#2484)

* [autoparallel] support origin activation ckpt on autoprallel system (#2468)

* [autochunk] support evoformer tracer (#2485)

Support the full Evoformer tracer, which is a main module of AlphaFold. Previously we only supported a simplified version of it.
1. support some of Evoformer's ops in fx
2. support Evoformer tests
3. add repos for test code

* [example] fix requirements (#2488)

* [zero] add unit testings for hybrid parallelism  (#2486)

* [hotfix] gpt example titans bug #2493

* polish code and fix dataloader bugs

* [hotfix] gpt example titans bug #2493 (#2494)

* [fx] allow control of ckpt_codegen init (#2498)

* [fx] allow control of ckpt_codegen init

Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, so no other codegen can be set.
This adds an argument to control whether ActivationCheckpointCodeGen is set in __init__.

* code style

* [example] dreambooth example

* add test_ci.sh to dreambooth

* [autochunk] support autochunk on evoformer (#2497)

* Revert "Update parallel_context.py (#2408)"

This reverts commit 7d5640b9db.

* add avg partition (#2483)

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>

* [auto-chunk] support extramsa (#3) (#2504)

* [utils] lazy init. (#2148)

* [utils] lazy init.

* [utils] remove description.

* [utils] complete.

* [utils] finalize.

* [utils] fix names.

* [autochunk] support parsing blocks (#2506)

* [zero] add strict ddp mode (#2508)

* [zero] add strict ddp mode

* [polish] add comments for strict ddp mode

* [zero] fix test error

* [doc] update opt and tutorial links (#2509)

* [workflow] fixed changed file detection (#2515)

Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
dev/gpt2_metainfo_patch
Boyuan Yao 2023-01-27 09:52:21 +08:00 committed by GitHub
parent ce08661eb1
commit 7a58dc5ad2
215 changed files with 8523 additions and 14916 deletions

.bdist.json (new file)

@ -0,0 +1,24 @@
{
"build": [
{
"torch_version": "1.11.0",
"cuda_image": "hpcaitech/cuda-conda:10.2"
},
{
"torch_version": "1.11.0",
"cuda_image": "hpcaitech/cuda-conda:11.3"
},
{
"torch_version": "1.12.1",
"cuda_image": "hpcaitech/cuda-conda:10.2"
},
{
"torch_version": "1.12.1",
"cuda_image": "hpcaitech/cuda-conda:11.3"
},
{
"torch_version": "1.12.1",
"cuda_image": "hpcaitech/cuda-conda:11.6"
}
]
}

.compatibility (new file)

@ -0,0 +1,3 @@
1.12.0-11.3.0
1.11.0-11.3.0
1.10.1-11.3.0

.github/workflows/README.md (new file)

@ -0,0 +1,149 @@
# CI/CD
## Table of Contents
- [CI/CD](#cicd)
- [Table of Contents](#table-of-contents)
- [Overview](#overview)
- [Workflows](#workflows)
- [Checks on Pull Requests](#checks-on-pull-requests)
- [Regular Checks](#regular-checks)
- [Release](#release)
- [Manual Dispatch](#manual-dispatch)
- [Release bdist wheel](#release-bdist-wheel)
- [Dispatch Example Test](#dispatch-example-test)
- [Compatibility Test](#compatibility-test)
- [User Friendliness](#user-friendliness)
- [Configuration](#configuration)
- [Progress Log](#progress-log)
## Overview
Automation makes our development more efficient as the machine automatically runs the pre-defined tasks for the contributors.
This saves a lot of manual work and allows the developers to fully focus on the features and bug fixes.
In Colossal-AI, we use [GitHub Actions](https://github.com/features/actions) to automate a wide range of workflows to ensure the robustness of the software.
In the section below, we will dive into the details of different workflows available.
## Workflows
### Checks on Pull Requests
| Workflow Name | File name | Description |
| --------------------------- | ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------- |
| `Build` | `build.yml` | This workflow is triggered when the label `Run build and Test` is assigned to a PR. It will run all the unit tests in the repository with 4 GPUs. |
| `Pre-commit` | `pre_commit.yml` | This workflow runs pre-commit checks for code style consistency. |
| `Report pre-commit failure` | `report_precommit_failure.yml` | This workflow will put up a comment in the PR to explain the pre-commit failure and the remedy. It is executed when `Pre-commit` is completed. |
| `Report test coverage` | `report_test_coverage.yml` | This workflow will put up a comment to report the test coverage results. It is executed when `Build` is completed. |
| `Test example` | `auto_example_check.yml` | The example will be automatically tested if its files are changed in the PR |
### Regular Checks
| Workflow Name | File name | Description |
| ----------------------- | ----------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `Test example` | `auto_example_check.yml` | This workflow will test all examples every Sunday |
| `Compatibility Test` | `auto_compatibility_test.yml` | This workflow will check the compatibility of Colossal-AI against PyTorch and CUDA every Sunday. The PyTorch and CUDA versions are specified in `.compatibility`. |
| `Build on 8 GPUs` | `build_gpu_8.yml` | This workflow will run the unit tests every day with 8 GPUs. |
| `Synchronize submodule` | `submodule.yml` | This workflow will check if any git submodule is updated. If so, it will create a PR to update the submodule pointers. |
| `Close inactive issues` | `close_inactive.yml` | This workflow will close issues which are stale for 14 days. |
### Release
| Workflow Name | File name | Description |
| --------------------------- | ------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `Draft GitHub Release Post` | `draft_github_release_post.yml` | Compose a GitHub release post draft based on the commit history. Triggered when the change of `version.txt` is merged. |
| `Release to PyPI` | `release_pypi.yml` | Build and release the wheel to PyPI. Triggered when the change of `version.txt` is merged. |
| `Release Nightly to PyPI` | `release_nightly.yml` | Build and release the nightly wheel to PyPI as `colossalai-nightly`. Automatically executed every Sunday. |
| `Release Docker` | `release_docker.yml` | Build and release the Docker image to DockerHub. Triggered when the change of `version.txt` is merged. |
| `Release bdist wheel` | `release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. Manually dispatched. See more details in the next section. |
| `Auto Release bdist wheel` | `auto_release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. Triggered when the change of `version.txt` is merged. Build specifications are stored in `.bdist.json`. |
| `Auto Compatibility Test` | `auto_compatibility_test.yml` | Check Colossal-AI's compatibility against the PyTorch and CUDA versions specified in `.compatibility`. Triggered when `version.txt` is changed in a PR. |
### Manual Dispatch
| Workflow Name | File name | Description |
| ---------------------------- | -------------------------------- | ------------------------------------------------------ |
| `Release bdist wheel` | `release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. |
| `Dispatch Example Test` | `dispatch_example_check.yml` | Manually test a specified example. |
| `Dispatch Compatibility Test` | `dispatch_compatiblity_test.yml` | Test PyTorch and Python compatibility. |
Refer to this [documentation](https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow) on how to manually trigger a workflow.
I will provide the details of each workflow below.
#### Release bdist wheel
Parameters:
- `torch version`: torch versions to test against; multiple versions are supported but must be separated by commas. The default value is `all`, which will test all available torch versions listed in this [repository](https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels), which is regularly updated.
- `cuda version`: cuda versions to test against; multiple versions are supported but must be separated by commas. The CUDA versions must be present in our [DockerHub repository](https://hub.docker.com/r/hpcaitech/cuda-conda).
- `ref`: input the branch or tag name to build the wheel for this ref.
#### Dispatch Example Test
Parameters:
- `example_directory`: the example directory to test. Multiple directories are supported and must be separated by commas, e.g. `language/gpt,images/vit`. Inputting only `language` or only `gpt` does not work.
#### Compatibility Test
Parameters:
- `torch version`: torch versions to test against; multiple versions are supported but must be separated by commas. The default value is `all`, which will test all available torch versions listed in this [repository](https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels).
- `cuda version`: cuda versions to test against; multiple versions are supported but must be separated by commas. The CUDA versions must be present in our [DockerHub repository](https://hub.docker.com/r/hpcaitech/cuda-conda).
> It only tests the compatibility of the main branch.
### User Friendliness
| Workflow Name | File name | Description |
| ----------------- | ----------------------- | -------------------------------------------------------------------------------------------------------------------------------------- |
| `issue-translate` | `translate_comment.yml` | This workflow is triggered when a new issue comment is created. The comment will be translated into English if not written in English. |
## Configuration
This section lists the files used to configure the workflow.
1. `.compatibility`
This `.compatibility` file tells GitHub Actions which PyTorch and CUDA versions to test against. Each line in the file is in the format `${torch-version}-${cuda-version}`, which is a tag for a Docker image. Thus, this tag must be present in the [docker registry](https://hub.docker.com/r/pytorch/conda-cuda) so that the test can be performed.
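For example, the initial `.compatibility` added in this PR (its full content appears earlier in this diff) lists three such tags:
```text
1.12.0-11.3.0
1.11.0-11.3.0
1.10.1-11.3.0
```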
2. `.bdist.json`
This file controls which PyTorch/CUDA-compatible pre-built wheels will be built and published. You can add a new entry according to the JSON schema below if there is a new wheel that needs to be built with AOT compilation of PyTorch extensions.
```json
{
"build": [
{
"torch_version": "",
"cuda_image": ""
},
]
}
```
## Progress Log
- [x] unit testing
- [x] test on PR
- [x] report test coverage
- [x] regular test
- [x] release
- [x] official release
- [x] nightly build
- [x] binary build
- [x] docker build
- [x] draft release post
- [x] pre-commit
- [x] check on PR
- [x] report failure
- [x] example check
- [x] check on PR
- [x] regular check
- [x] manual dispatch
- [x] compatibility check
- [x] manual dispatch
- [x] auto test when release
- [x] helpers
- [x] comment translation
- [x] submodule update
- [x] close inactive issue


@ -0,0 +1,74 @@
name: Compatibility Test
on:
pull_request:
paths:
- 'version.txt'
- '.compatibility'
# run at 03:00 of every Sunday (Singapore time), which is 19:00 UTC on Saturday
schedule:
- cron: '0 19 * * 6'
jobs:
matrix_preparation:
name: Prepare Container List
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v3
- id: set-matrix
run: |
IFS=','
DOCKER_IMAGE=()
while read tag; do
DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tag}\"")
done <.compatibility
container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" )
container="[${container}]"
echo "$container"
echo "::set-output name=matrix::{\"container\":$(echo "$container")}"
build:
name: Test for PyTorch Compatibility
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
timeout-minutes: 120
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Install Colossal-AI
run: |
pip install -v --no-cache-dir .
pip install -r requirements/requirements-test.txt
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest tests
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

.github/workflows/auto_example_check.yml (new file)

@ -0,0 +1,143 @@
name: Test Example
on:
pull_request:
# any change in the examples folder will trigger check for the corresponding example.
paths:
- 'examples/**'
# run at 00:00 of every Sunday (Singapore time), which is 16:00 UTC on Saturday
schedule:
- cron: '0 16 * * 6'
jobs:
# This job detects changed example files and outputs a matrix containing all the corresponding directory names.
detect-changed-example:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
name: Detect changed example files
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.sha }}
- name: Locate base commit
id: locate-base-sha
run: |
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Get all changed example files
id: changed-files
uses: tj-actions/changed-files@v35
with:
base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
- name: setup matrix
id: setup-matrix
run: |
changedFileName=""
for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
changedFileName="${file}:${changedFileName}"
done
echo "$changedFileName was changed"
res=`python .github/workflows/scripts/example_checks/detect_changed_example.py --fileNameList $changedFileName`
echo "All changed examples are $res"
if [ "$res" = "[]" ]; then
echo "anyChanged=false" >> $GITHUB_OUTPUT
echo "matrix=null" >> $GITHUB_OUTPUT
else
dirs=$( IFS=',' ; echo "${res[*]}" )
echo "anyChanged=true" >> $GITHUB_OUTPUT
echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
fi
# If no file is changed, it will prompt an error showing that the matrix does not have a value.
check-changed-example:
# Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' &&
needs.detect-changed-example.outputs.anyChanged == 'true'
name: Test the changed example
needs: detect-changed-example
runs-on: [self-hosted, gpu]
strategy:
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
steps:
- uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
pip install -v .
- name: Test the example
run: |
example_dir=${{ matrix.directory }}
cd "${PWD}/examples/${example_dir}"
bash test_ci.sh
env:
NCCL_SHM_DISABLE: 1
# This is for all files' weekly check. Specifically, this job is to find all the directories.
matrix_preparation:
if: |
github.repository == 'hpcaitech/ColossalAI' &&
github.event_name == 'schedule'
name: Prepare matrix for weekly check
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: setup matrix
id: setup-matrix
run: |
res=`python .github/workflows/scripts/example_checks/check_example_weekly.py`
all_loc=$( IFS=',' ; echo "${res[*]}" )
echo "Found the examples: $all_loc"
echo "matrix={\"directory\":$(echo "$all_loc")}" >> $GITHUB_OUTPUT
weekly_check:
if: |
github.repository == 'hpcaitech/ColossalAI' &&
github.event_name == 'schedule'
name: Weekly check all examples
needs: matrix_preparation
runs-on: [self-hosted, gpu]
strategy:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
timeout-minutes: 10
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
pip install -v .
- name: Traverse all files
run: |
example_dir=${{ matrix.directory }}
echo "Testing ${example_dir} now"
cd "${PWD}/examples/${example_dir}"
bash test_ci.sh
env:
NCCL_SHM_DISABLE: 1


@ -0,0 +1,70 @@
name: Auto Release bdist wheel
on:
workflow_dispatch:
pull_request:
paths:
- 'version.txt'
types:
- closed
jobs:
matrix_preparation:
name: Prepare Container List
if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v3
- id: set-matrix
run: |
bdist=$(cat .bdist.json | tr '\n' ' ')
echo "matrix=${bdist}" >> $GITHUB_OUTPUT
build:
name: Release bdist wheels
needs: matrix_preparation
runs-on: [self-hosted, gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.build.cuda_image }}
options: --gpus all --rm
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
# cub is for cuda 10.2
- name: Copy scripts
run: |
cp -r ./.github/workflows/scripts/* ./
# link the cache directories to the current path
ln -s /github/home/conda_pkgs ./conda_pkgs
ln -s /github/home/pip_wheels ./pip_wheels
# set the conda package path
echo "pkgs_dirs:\n - $PWD/conda_pkgs" > ~/.condarc
# set safe directory
git config --global --add safe.directory /__w/ColossalAI/ColossalAI
# get cub package for cuda 10.2
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
- name: Build bdist wheel
run: |
pip install beautifulsoup4 requests packaging
python ./build_colossalai_wheel.py --torch_version $TORCH_VERSIONS
env:
TORCH_VERSIONS: ${{ matrix.build.torch_version }}
- name: 🚀 Deploy
uses: garygrossgarten/github-action-scp@release
with:
local: all_dist
remote: ${{ secrets.PRIVATE_PYPI_DIR }}
host: ${{ secrets.PRIVATE_PYPI_HOST }}
username: ${{ secrets.PRIVATE_PYPI_USER }}
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}


@ -20,15 +20,26 @@ jobs:
- uses: actions/checkout@v2
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.sha }}
- name: Locate base commit
id: locate-base-sha
run: |
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Find the changed files
id: find-changed-files
uses: tj-actions/changed-files@v35
with:
since_last_remote_commit: true
base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
files: |
op_builder/**
colossalai/kernel/**
setup.py
- name: List changed files
run: |
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
@ -75,12 +86,26 @@ jobs:
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest tests
PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
- name: Collate artifact
env:
PR_NUMBER: ${{ github.event.number }}
run: |
mkdir report
echo $PR_NUMBER > ./report/pr_number
mv coverage.xml ./report
- name: Upload test coverage artifact
uses: actions/upload-artifact@v3
with:
name: report
path: report/
- name: Store Cache
run: |
# -p flag is required to preserve the file timestamp to avoid ninja rebuild


@ -1,119 +0,0 @@
name: Test Example
on:
pull_request:
# So only the changes in examples folder will trigger jobs below.
paths:
- 'examples/**'
# run at 00:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00
schedule:
- cron: '0 16 * * 6'
jobs:
# This is for changed example files detect and output a matrix containing all the corresponding directory name.
detect-changed-example:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
name: Check out all files
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Get all changed example files
id: changed-files
uses: tj-actions/changed-files@v35
# Using this can trigger action each time a PR is submitted.
with:
since_last_remote_commit: true
- name: setup matrix
id: set-matrix
run: |
changedFileName=""
for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
changedFileName="${file}:${changedFileName}"
done
echo "$changedFileName was changed"
res=`python .github/workflows/scripts/changed_example.py --fileNameList $changedFileName`
echo "All changed files are $res"
loc=$( IFS=',' ; echo "${res[*]}" )
echo "$loc"
echo "::set-output name=matrix::{\"loc\":$(echo "$loc")}"
# If no file is changed, it will prompt an error and shows the matrix do not have value.
check-all-changed-files:
# Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
name: Test each changed example files
needs: detect-changed-example
runs-on: [self-hosted, gpu]
strategy:
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependancies
run: |
pip install -r ./requirements/requirements.txt
pip install colossalai
- name: List all changed example files
run: |
res=${{ matrix.loc }}
cd "${PWD}/examples/${res}"
bash test_ci.sh
# This is for all files' weekly check. Specifically, this job is to find all the directories.
matrix_preparation:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
name: Prepare Directory List for All files
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: setup matrix
id: set-matrix
run: |
res=`python .github/workflows/scripts/weekly_check_example.py`
all_loc=$( IFS=',' ; echo "${res[*]}" )
echo "$all_loc"
echo "::set-output name=matrix::{\"all_loc\":$(echo "$all_loc")}"
weekly_check:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
name: Weekly check all examples
needs: matrix_preparation
runs-on: [self-hosted, gpu]
strategy:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: Install the requirements
run: |
pip install -r ./requirements/requirements.txt
pip install colossalai
- name: Traverse all files
run: |
dir=${{ matrix.all_loc }}
echo "${dir} is current directory"
cd "${PWD}/examples/${dir}"
bash test_ci.sh


@ -1,4 +1,4 @@
name: Compatibility Test
name: Dispatch Compatibility Test
on:
workflow_dispatch:


@ -8,7 +8,7 @@ on:
required: true
jobs:
manual_check_matrix_preparation:
matrix_preparation:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
@ -16,31 +16,24 @@ jobs:
name: Check the examples user want
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix-1.outputs.matrix }}
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: Get manual directories
id: set-matrix-1
- name: Set up matrix
id: set-matrix
env:
check_dir: ${{ inputs.example_directory }}
run: |
all_mannual_check_dir=()
for cdi in $check_dir
do
all_mannual_check_dir+=("\"${cdi}\"")
done
man_loc=$( IFS=',' ; echo "${all_mannual_check_dir[*]}" )
res=`python .github/workflows/scripts/input_check_example.py --fileNameList $man_loc`
echo "${res} is file existance. 1 for all exist, -1 for at least one file not exist."
if [ res == -1 ];then
exit(1)
res=`python .github/workflows/scripts/example_checks/check_dispatch_inputs.py --fileNameList $check_dir`
if [ res == "failure" ];then
exit -1
fi
man_loc="[${man_loc}]"
echo "$man_loc"
echo "::set-output name=matrix::{\"man_loc\":$(echo "$man_loc")}"
dirs="[${check_dir}]"
echo "Testing examples in $dirs"
echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
manual_check:
test_example:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
@ -52,16 +45,19 @@ jobs:
matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: Install the requirements
- name: Install Colossal-AI
run: |
pip install -r ./requirements/requirements.txt
pip install colossalai
- name: Traverse all files
pip install -v .
- name: Test the example
run: |
dir=${{ matrix.man_loc }}
echo "${dir} is current directory"
dir=${{ matrix.directory }}
echo "Testing ${dir} now"
cd "${PWD}/examples/${dir}"
bash test_ci.sh
env:
NCCL_SHM_DISABLE: 1


@ -8,11 +8,10 @@ on:
types:
- closed
jobs:
release:
name: Draft Release Post
if: github.repository == 'hpcaitech/ColossalAI'
if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2

.github/workflows/pre_commit.yml (new file)

@ -0,0 +1,71 @@
name: pre-commit
on:
pull_request:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.sha }}
# the PR branch and the hpcaitech/colossal-ai main branch
# must share a common commit, we need to locate that commit,
# which is the commit checked-out or forked when the PR branch is created
# such that we can look for files changed since that commit
- name: Locate base commit
id: locate-base-sha
run: |
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Find the changed files
id: find-changed-files
uses: tj-actions/changed-files@v35
with:
base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
- name: List all changed files
run: |
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
echo "$file was changed"
done
- uses: actions/setup-python@v3
- name: Cache pre-commit hooks
uses: actions/cache@v3
with:
path: ~/.cache/pre-commit
key: ${{ runner.os }}-pre-commit-hooks
- name: Set up pre-commit
run: |
pip install pre-commit
pre-commit install
- name: Run pre-commit on Changed Files
id: precommit
run: |
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
echo "======= running pre-commit on ${file} ======="
pre-commit run --files $file
done
- name: Save PR number
if: always()
env:
PR_NUMBER: ${{ github.event.number }}
run: |
mkdir -p ./pr
echo $PR_NUMBER > ./pr/pr_number
- uses: actions/upload-artifact@v3
if: always()
with:
name: pr_number
path: pr/


@ -2,13 +2,16 @@ name: Publish Docker Image to DockerHub
on:
workflow_dispatch:
release:
types: [published]
pull_request:
paths:
- 'version.txt'
types:
- closed
jobs:
release:
name: Publish Docker Image to DockerHub
if: github.repository == 'hpcaitech/ColossalAI'
if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: "hpcaitech/docker-in-docker:latest"
@ -18,23 +21,17 @@ jobs:
with:
fetch-depth: 0
- name: Build Docker
id: build
run: |
version=$(cat version.txt)
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t hpcaitech/colossalai:$version ./docker
tag=hpcaitech/colossalai:$version
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker
echo "tag=${tag}" >> $GITHUB_OUTPUT
- name: Log in to Docker Hub
uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
with:
images: hpcaitech/colossalai
- name: Build and push Docker image
uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Push Docker image
run: |
docker push ${{ steps.build.outputs.tag }}


@ -1,73 +1,29 @@
name: Release bdist wheel for Nightly versions
name: Publish Nightly Version to PyPI
on:
schedule:
# run at 00:00 of every Sunday
- cron: '0 0 * * 6'
workflow_dispatch:
schedule:
- cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time
jobs:
matrix_preparation:
name: Prepare Container List
build-n-publish:
if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI'
name: Build and publish Python 🐍 distributions 📦 to PyPI
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
timeout-minutes: 20
steps:
- id: set-matrix
run: |
matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]"
echo $matrix
echo "::set-output name=matrix::{\"container\":$(echo $matrix)}"
- uses: actions/checkout@v2
build:
name: Release bdist wheels
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
runs-on: [self-hosted, gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
# cub is for cuda 10.2
- name: Copy scripts and checkout
run: |
cp -r ./.github/workflows/scripts/* ./
ln -s /github/home/pip_wheels ./pip_wheels
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
- name: Build bdist wheel
run: |
pip install beautifulsoup4 requests packaging
python ./build_colossalai_wheel.py --nightly
- name: 🚀 Deploy
uses: garygrossgarten/github-action-scp@release
with:
local: all_dist
remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
host: ${{ secrets.PRIVATE_PYPI_HOST }}
username: ${{ secrets.PRIVATE_PYPI_USER }}
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
remove_old_build:
name: Remove old nightly build
runs-on: ubuntu-latest
needs: build
steps:
- name: executing remote ssh commands using password
uses: appleboy/ssh-action@master
env:
BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
with:
host: ${{ secrets.PRIVATE_PYPI_HOST }}
username: ${{ secrets.PRIVATE_PYPI_USER }}
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
envs: BUILD_DIR
script: |
cd $BUILD_DIR
find . -type f -mtime +0 -exec rm -f {} +
script_stop: true
- uses: actions/setup-python@v2
with:
python-version: '3.8.14'
- run: NIGHTLY=1 python setup.py sdist build
# publish to PyPI if executed on the main branch
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
verbose: true


@ -0,0 +1,67 @@
name: Report Precommit Failure
on:
workflow_run:
workflows: [pre-commit]
types:
- completed
jobs:
# comment with a message on how to do pre-commit
# if the pre-commit check was not passed
report-precommit-failure:
runs-on: ubuntu-latest
if: ${{ github.event.workflow_run.conclusion == 'failure' }}
steps:
- name: 'Download artifact'
uses: actions/github-script@v6
with:
script: |
let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: context.payload.workflow_run.id,
});
let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
return artifact.name == "pr_number"
})[0];
let download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: matchArtifact.id,
archive_format: 'zip',
});
let fs = require('fs');
fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/pr_number.zip`, Buffer.from(download.data));
- name: 'Unzip artifact'
run: unzip pr_number.zip
- name: 'Comment on PR'
uses: actions/github-script@v6
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
let fs = require('fs');
let issue_number = Number(fs.readFileSync('./pr_number'));
let owner = context.repo.owner;
let repo = context.repo.repo;
let run_id = context.payload.workflow_run.id;
let run_url = `https://github.com/${owner}/${repo}/actions/runs/${run_id}`
let body = `
Your pre-commit check failed, follow the steps to run pre-commit on your file for code style consistency.
1. install pre-commit via "pip install pre-commit"
2. install pre-commit hooks via "pre-commit install"
3. run pre-commit on file with format error via "pre-commit run --files path" by replacing "path" with the actual file path
4. commit and push to your branch
View your job at ${run_url}.
Read our "CONTRIBUTING.md" for more reference to the code style.
`;
await github.rest.issues.createComment({
owner: owner,
repo: repo,
issue_number: issue_number,
body: body
});


@ -0,0 +1,74 @@
name: Report Test Coverage
on:
workflow_run:
workflows: [Build]
types:
- completed
jobs:
report-test-coverage:
runs-on: ubuntu-latest
steps:
- name: 'Download artifact'
uses: actions/github-script@v6
with:
script: |
let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: context.payload.workflow_run.id,
});
let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
return artifact.name == "report"
})[0];
let download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: matchArtifact.id,
archive_format: 'zip',
});
let fs = require('fs');
fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/report.zip`, Buffer.from(download.data));
- name: 'Unzip artifact'
run: |
unzip report.zip
- name: Code Coverage Report
uses: irongut/CodeCoverageSummary@v1.3.0
with:
filename: coverage.xml
badge: true
format: markdown
hide_branch_rate: false
hide_complexity: false
indicators: true
output: both
thresholds: '80 90'
- name: Make Coverage Report Collapsable
run: |
sed -i '2 i <details>' code-coverage-results.md
sed -i '3 i <summary>Click me to view the complete report</summary>' code-coverage-results.md
echo "</details>" >> code-coverage-results.md
- name: 'Comment on PR'
uses: actions/github-script@v6
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
let fs = require('fs');
let issue_number = Number(fs.readFileSync('./pr_number'));
let owner = context.repo.owner;
let repo = context.repo.repo;
let run_id = context.payload.workflow_run.id;
let run_url = `https://github.com/${owner}/${repo}/actions/runs/${run_id}`
let body = fs.readFileSync('./code-coverage-results.md', {encoding:'utf8', flag:'r'})
await github.rest.issues.createComment({
owner: owner,
repo: repo,
issue_number: issue_number,
body: body
});


@ -0,0 +1,27 @@
import argparse
import os
def check_inputs(input_list):
for path in input_list:
real_path = os.path.join('examples', path)
if not os.path.exists(real_path):
return False
return True
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
args = parser.parse_args()
name_list = args.fileNameList.split(",")
is_correct = check_inputs(name_list)
if is_correct:
print('success')
else:
print('failure')
if __name__ == '__main__':
main()
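As a usage sketch (mirroring how the manual example-dispatch workflow shown earlier in this diff calls this script; the directory names below are illustrative only):
```bash
# Directories are ','-separated and resolved against the examples/ folder;
# the script prints 'success' only if every examples/<dir> exists, otherwise 'failure'.
python .github/workflows/scripts/example_checks/check_dispatch_inputs.py \
    --fileNameList "language/gpt,images/vit"
```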


@ -5,9 +5,9 @@ def show_files(path, all_files):
# Traverse all the folder/file in current directory
file_list = os.listdir(path)
# Determine the element is folder or file. If file, pass it into list, if folder, recurse.
for file in file_list:
for file_name in file_list:
# Get the abs directory using os.path.join() and store into cur_path.
cur_path = os.path.join(path, file)
cur_path = os.path.join(path, file_name)
# Determine whether folder
if os.path.isdir(cur_path):
show_files(cur_path, all_files)
@ -26,9 +26,8 @@ def main():
for file_loc in contents:
split_loc = file_loc.split('/')
# must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
if len(split_loc) - split_loc.index('examples') >= 3:
tmp_loc = split_loc[(split_loc.index('examples') + 1):(split_loc.index('examples') + 3)]
re_loc = join(tmp_loc, '/')
if len(split_loc) >= 4:
re_loc = '/'.join(split_loc[1:3])
if re_loc not in all_loc:
all_loc.append(re_loc)
print(all_loc)
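For context, a sketch of how the weekly job in `auto_example_check.yml` (shown earlier in this diff) drives this script; the example output is illustrative:
```bash
# Enumerate every examples/<area>/<application> directory for the weekly check.
res=$(python .github/workflows/scripts/example_checks/check_example_weekly.py)
echo "Found the examples: $res"   # e.g. ['images/vit', 'language/gpt']
```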


@ -3,14 +3,19 @@ import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--fileNameList', type=str)
parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
args = parser.parse_args()
name_list = args.fileNameList.split(":")
folder_need_check = set()
for loc in name_list:
# Find only the sub-folder of 'example' folder
# Find only the sub-sub-folder of 'example' folder
# the examples folder structure is like
# - examples
# - area
# - application
# - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
folder_need_check.add(loc.split("/")[1] + "/" + loc.split("/")[2])
folder_need_check.add('/'.join(loc.split("/")[1:3]))
# Output the result using print. Then the shell can get the values.
print(list(folder_need_check))
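As a usage sketch (matching the invocation in `auto_example_check.yml` above; the changed-file list is illustrative only):
```bash
# Changed file paths are ':'-separated; the script prints the unique
# examples/<area>/<application> directories whose files were touched.
python .github/workflows/scripts/example_checks/detect_changed_example.py \
    --fileNameList "examples/language/gpt/test_ci.sh:examples/images/vit/README.md"
# e.g. ['language/gpt', 'images/vit']
```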


@ -1,23 +0,0 @@
import argparse
import os
def detect_correct(loc_li):
for loc in loc_li:
real_loc = 'examples/' + eval(loc)
if not os.path.exists(real_loc):
return -1
return 1
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--fileNameList', type=str)
args = parser.parse_args()
name_list = args.fileNameList.split(",")
result = detect_correct(name_list)
print(result)
if __name__ == '__main__':
main()

.github/workflows/translate_comment.yml (new file)

@ -0,0 +1,18 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: usthe/issues-translate-action@v2.7
with:
IS_MODIFY_TITLE: false
# not required, default false. Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not required. Customize the translation robot prefix message.

.gitignore

@ -151,3 +151,7 @@ colossalai/version.py
# ignore python interface defition file
.pyi
# ignore coverage test file
coverage.lcov
coverage.xml


@ -5,10 +5,10 @@
Colossal-AI: 一个面向大模型时代的通用深度学习系统
<h3> <a href="https://arxiv.org/abs/2110.14883"> 论文 </a> |
<a href="https://www.colossalai.org/"> 文档 </a> |
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> 例程 </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
<h3> <a href="https://arxiv.org/abs/2110.14883"> 论文 </a> |
<a href="https://www.colossalai.org/"> 文档 </a> |
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> 例程 </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
<a href="https://medium.com/@hpcaitech"> 博客 </a></h3>
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
@ -35,7 +35,7 @@
<li><a href="#为何选择-Colossal-AI">为何选择 Colossal-AI</a> </li>
<li><a href="#特点">特点</a> </li>
<li>
<a href="#并行训练样例展示">并行训练样例展示</a>
<a href="#并行训练样例展示">并行训练样例展示</a>
<ul>
<li><a href="#GPT-3">GPT-3</a></li>
<li><a href="#GPT-2">GPT-2</a></li>
@ -47,14 +47,14 @@
</ul>
</li>
<li>
<a href="#单GPU训练样例展示">单GPU训练样例展示</a>
<a href="#单GPU训练样例展示">单GPU训练样例展示</a>
<ul>
<li><a href="#GPT-2-Single">GPT-2</a></li>
<li><a href="#PaLM-Single">PaLM</a></li>
</ul>
</li>
<li>
<a href="#推理-Energon-AI-样例展示">推理 (Energon-AI) 样例展示</a>
<a href="#推理-Energon-AI-样例展示">推理 (Energon-AI) 样例展示</a>
<ul>
<li><a href="#GPT-3-Inference">GPT-3</a></li>
<li><a href="#OPT-Serving">1750亿参数OPT在线推理服务</a></li>
@ -62,7 +62,7 @@
</ul>
</li>
<li>
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI 成功案例</a>
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI 成功案例</a>
<ul>
<li><a href="#AIGC">AIGC: 加速 Stable Diffusion</a></li>
<li><a href="#生物医药">生物医药: 加速AlphaFold蛋白质结构预测</a></li>
@ -131,7 +131,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/(updated)GPT-2.png" width=800>
- 用相同的硬件训练24倍大的模型
- 超3倍的吞吐量
- 超3倍的吞吐量
### BERT
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BERT.png" width=800/>
@ -145,7 +145,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), 由Meta发布的1750亿语言模型由于完全公开了预训练参数权重因此促进了下游任务和应用部署的发展。
- 加速45%仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://service.colossalai.org/opt)
- 加速45%仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/zh-Hans/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md)
请访问我们的 [文档](https://www.colossalai.org/) 和 [例程](https://github.com/hpcaitech/ColossalAI-Examples) 以了解详情。
@ -199,7 +199,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_serving.png" width=800/>
</p>
- [OPT推理服务](https://service.colossalai.org/opt): 无需注册免费体验1750亿参数OPT在线推理服务
- [OPT推理服务](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/zh-Hans/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md): 无需注册免费体验1750亿参数OPT在线推理服务
<p id="BLOOM-Inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
@ -255,6 +255,28 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
## 安装
### 从PyPI安装
您可以用下面的命令直接从PyPI上下载并安装Colossal-AI。我们默认不会安装PyTorch扩展包
```bash
pip install colossalai
```
但是如果你想在安装时就直接构建PyTorch扩展您可以设置环境变量`CUDA_EXT=1`.
```bash
CUDA_EXT=1 pip install colossalai
```
**否则PyTorch扩展只会在你实际需要使用他们时在运行时里被构建。**
与此同时我们也每周定时发布Nightly版本这能让你提前体验到新的feature和bug fix。你可以通过以下命令安装Nightly版本。
```bash
pip install colossalai-nightly
```
### 从官方安装
您可以访问我们[下载](https://www.colossalai.org/download)页面来安装Colossal-AI在这个页面上发布的版本都预编译了CUDA扩展。
@ -274,10 +296,10 @@ pip install -r requirements/requirements.txt
pip install .
```
如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装):
我们默认在`pip install`时不安装PyTorch扩展而是在运行时临时编译如果你想要提前安装这些扩展的话在使用融合优化器时会用到可以使用一下命令。
```shell
NO_CUDA_EXT=1 pip install .
CUDA_EXT=1 pip install .
```
<p align="right">(<a href="#top">返回顶端</a>)</p>
@ -327,6 +349,11 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash
<p align="right">(<a href="#top">返回顶端</a>)</p>
## CI/CD
我们使用[GitHub Actions](https://github.com/features/actions)来自动化大部分开发以及部署流程。如果想了解这些工作流是如何运行的,请查看这个[文档](.github/workflows/README.md).
## 引用我们
```
@ -338,4 +365,6 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash
}
```
Colossal-AI 已被 [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/) 等顶级会议录取为官方教程。
<p align="right">(<a href="#top">返回顶端</a>)</p>


@ -149,7 +149,7 @@ distributed training and inference in a few lines.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model released by Meta, which stimulates AI programmers to perform various downstream tasks and application deployments because public pretrained model weights.
- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt)
- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/en/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md)
Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details.
@ -202,7 +202,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_serving.png" width=800/>
</p>
- [OPT Serving](https://service.colossalai.org/opt): Try 175-billion-parameter OPT online services for free, without any registration whatsoever.
- [OPT Serving](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/en/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md): Try 175-billion-parameter OPT online services for free, without any registration whatsoever.
<p id="BLOOM-Inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
@ -257,9 +257,32 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
## Installation
### Install from PyPI
You can easily install Colossal-AI with the following command. **By default, we do not build PyTorch extensions during installation.**
```bash
pip install colossalai
```
However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
```bash
CUDA_EXT=1 pip install colossalai
```
**Otherwise, CUDA kernels will be built during runtime when you actually need them.**
We also release a nightly version to PyPI on a weekly basis. This allows you to access unreleased features and bug fixes from the main branch.
Installation can be made via
```bash
pip install colossalai-nightly
```
### Download From Official Releases
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built CUDA extensions.
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built PyTorch extensions.
### Download From Source
@ -270,9 +293,6 @@ You can visit the [Download](https://www.colossalai.org/download) page to downlo
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
# install dependency
pip install -r requirements/requirements.txt
# install colossalai
pip install .
```
@ -333,6 +353,11 @@ Thanks so much to all of our amazing contributors!
<p align="right">(<a href="#top">back to top</a>)</p>
## CI/CD
We leverage the power of [GitHub Actions](https://github.com/features/actions) to automate our development, release and deployment workflows. Please check out this [documentation](.github/workflows/README.md) on how the automated workflows are operated.
## Cite Us
```
@ -344,4 +369,6 @@ Thanks so much to all of our amazing contributors!
}
```
Colossal-AI has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), etc.
<p align="right">(<a href="#top">back to top</a>)</p>

View File

@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
if 'activation_checkpoint' in user_node.meta:
shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
if 'activation_checkpoint' in node.meta:
comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule):
"""
This pass is used to add the activation checkpoint annotation to the newly inserted nodes.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
for node in nodes:
if 'activation_checkpoint' not in node.meta:
from .runtime_preparation_pass import size_processing
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
if 'activation_checkpoint' in user_node.meta:
user_act_annotation = user_node.meta['activation_checkpoint']
break
for input_node in node._input_nodes.keys():
if 'activation_checkpoint' in input_node.meta:
input_act_annotation = input_node.meta['activation_checkpoint']
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
node.meta['activation_checkpoint'] = user_act_annotation
return gm
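The hunks above all follow the same pattern: whenever a new node is inserted into the traced graph, the `activation_checkpoint` entry of its `meta` dict is copied from the neighbouring node. A minimal, self-contained sketch of that pattern on a toy FX graph (the module, node names and annotation value are illustrative, not taken from this diff):

```python
import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):          # toy module, only used to obtain an FX graph
    def forward(self, x):
        return torch.relu(x + 1)

gm = symbolic_trace(Toy())
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
add_node.meta['activation_checkpoint'] = 0          # pretend this node is annotated

# insert a new node and propagate the annotation, mirroring the passes above
with gm.graph.inserting_after(add_node):
    new_node = gm.graph.call_function(torch.neg, args=(add_node,))
if 'activation_checkpoint' in add_node.meta:
    new_node.meta['activation_checkpoint'] = add_node.meta['activation_checkpoint']
```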

View File

@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
if 'activation_checkpoint' in node.meta:
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
user_list = list(node.users.keys())
for user in user_list:

View File

@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
)
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
into the forward function.
'''
def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
'''
Args:
@ -59,18 +60,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
pass
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
'''
This method is used to search the best logical mesh shape for the given world size
based on the alpha_beta_dict.
For example:
if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
'''
# TODO: implement this function
return (world_size, 1)
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
'''
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
@ -93,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
return strategies_constructor
def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
'''
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
@ -109,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
return solution
def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor):
'''
This method is used to transform the original graph to the sharded graph.
@ -127,39 +116,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
def initialize_device_mesh(world_size: int = -1,
physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None):
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None):
'''
This method is used to initialize the device mesh.
Args:
world_size(optional): the size of device mesh. If the world_size is -1,
world_size: the size of device mesh. If the world_size is -1,
the world size will be set to the number of GPUs in the current machine.
physical_devices: the physical devices used to initialize the device mesh.
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
generated by profile_alpha_beta function.
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
'''
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
device1d = [i for i in range(world_size)]
if physical_devices is None:
physical_devices = [i for i in range(world_size)]
physical_mesh = torch.tensor(physical_devices)
if alpha_beta_dict is None:
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
alpha_beta_dict = profile_alpha_beta(device1d)
ab_profiler = AlphaBetaProfiler(physical_devices)
alpha_beta_dict = ab_profiler.alpha_beta_dict
else:
ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
if logical_mesh_shape is None:
if logical_mesh_shape is None and logical_mesh_id is None:
# search for the best logical mesh shape
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
logical_mesh_id = ab_profiler.search_best_logical_mesh()
logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
logical_mesh_shape = logical_mesh_id.shape
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
elif logical_mesh_shape is not None and logical_mesh_id is None:
logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
physical_mesh = torch.tensor(device1d)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True)
@ -192,10 +198,10 @@ def initialize_model(model: nn.Module,
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
'''
tracer = ColoTracer()
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh)
if load_solver_solution:
@ -224,6 +230,7 @@ def autoparallelize(model: nn.Module,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
@ -245,6 +252,7 @@ def autoparallelize(model: nn.Module,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@ -254,7 +262,9 @@ def autoparallelize(model: nn.Module,
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
'''
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
logical_mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
@ -263,7 +273,7 @@ def autoparallelize(model: nn.Module,
device_mesh,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solver_solution_path=solver_solution_path,
solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)
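As a hedged usage sketch of the updated mesh initialization: the call below follows the signature shown in this hunk, while the 4-rank setup and the `(2, 2)` mesh are made-up values and the distributed process group is assumed to already be initialized.

```python
# build a 2 x 2 logical mesh over ranks 0..3 from an explicit shape
device_mesh = initialize_device_mesh(world_size=4,
                                     physical_devices=[0, 1, 2, 3],
                                     logical_mesh_shape=(2, 2))

# The new logical_mesh_id argument lets the caller hand over the mesh layout directly,
# e.g. logical_mesh_id=torch.arange(4).reshape(2, 2); autoparallelize(...) simply
# forwards logical_mesh_id to this function.
```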

View File

@ -11,6 +11,7 @@ from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
@ -27,5 +28,5 @@ __all__ = [
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
]

View File

@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
return OperationDataType.ARG
def _get_arg_value(idx):
non_tensor = False
if isinstance(self.node.args[idx], Node):
meta_data = self.node.args[idx]._meta_data
# The meta_data of node type argument could also possibly be a non-tensor object.
if not isinstance(meta_data, torch.Tensor):
assert isinstance(meta_data, (int, float))
meta_data = torch.Tensor([meta_data]).to('meta')
non_tensor = True
else:
# this is in fact real data, such as the int 1,
# but we can treat it as meta data
# since it won't affect strategy generation
assert isinstance(self.node.args[idx], (int, float))
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
return meta_data
non_tensor = True
input_meta_data = _get_arg_value(0)
other_meta_data = _get_arg_value(1)
return meta_data, non_tensor
input_meta_data, non_tensor_input = _get_arg_value(0)
other_meta_data, non_tensor_other = _get_arg_value(1)
output_meta_data = self.node._meta_data
# we need record op_data with non-tensor data in this list,
# and filter the non-tensor op_data in post_process.
self.non_tensor_list = []
input_op_data = OperationData(name=str(self.node.args[0]),
type=_get_op_data_type(input_meta_data),
data=input_meta_data,
@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=bcast_shape)
if non_tensor_input:
self.non_tensor_list.append(input_op_data)
if non_tensor_other:
self.non_tensor_list.append(other_op_data)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping
@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
op_data_mapping = self.get_operation_data_mapping()
for op_name, op_data in op_data_mapping.items():
if not isinstance(op_data.data, torch.Tensor):
if op_data in self.non_tensor_list:
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
strategy.sharding_specs.pop(op_data)
else:
# convert the logical sharding spec to physical sharding spec if broadcast
# e.g. torch.rand(4, 4) + torch.rand(4)

View File

@ -5,6 +5,7 @@ import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@ -35,12 +36,14 @@ class NodeHandler(ABC):
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.shard_option = shard_option
def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
"""
@ -181,6 +184,21 @@ class NodeHandler(ABC):
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
check_sharding_spec_validity(sharding_spec, op_data.data)
remove_strategy_list = []
for strategy in self.strategies_vector:
shard_level = 0
for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axis in sharding_spec.dim_partition_dict.items():
shard_level += len(shard_axis)
if self.shard_option == ShardOption.SHARD and shard_level == 0:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
remove_strategy_list.append(strategy)
for strategy in remove_strategy_list:
self.strategies_vector.remove(strategy)
return self.strategies_vector
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:

View File

@ -0,0 +1,17 @@
from enum import Enum
__all__ = ['ShardOption']
class ShardOption(Enum):
"""
This enum class defines the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be sharded using at least one device mesh axis.
FULL_SHARD: We require the node to be sharded using all device mesh axes.
"""
STANDARD = 0
SHARD = 1
FULL_SHARD = 2
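A hedged sketch of threading the new option through a handler. `LinearFunctionHandler` is exported next to `ShardOption` in the package `__init__` above, but the surrounding `node`, `device_mesh` and `strategies_vector` objects, and whether every concrete handler forwards the keyword, are assumptions rather than facts from this diff.

```python
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, ShardOption

# ShardOption.SHARD filters out fully replicated strategies (shard_level == 0);
# ShardOption.FULL_SHARD additionally drops strategies that shard along a single mesh axis.
handler = LinearFunctionHandler(node,
                                device_mesh,
                                strategies_vector,
                                shard_option=ShardOption.SHARD)
```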

View File

@ -0,0 +1,523 @@
from typing import Any, Dict, Iterable, List, Tuple
import torch
import colossalai
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
if CODEGEN_AVAILABLE:
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
"""
Generate chunk slice string, eg. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :]
Args:
chunk_dim (int)
chunk_indice_name (str): chunk indice name
shape (List): node shape
Returns:
new_shape (str): return slice
"""
new_shape = "["
for idx, _ in enumerate(shape):
if idx == chunk_dim:
new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
else:
new_shape += ":"
new_shape += ", "
new_shape = new_shape[:-2] + "]"
return new_shape
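# Illustrative example (not part of the diff): chunking dim 2 of a 4-D tensor,
#   _gen_chunk_slice_dim(2, "chunk_idx", [1, 8, 64, 32])
# returns "[:, :, chunk_idx:chunk_idx + chunk_size, :]".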
def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
"""
Generate chunk loop start
eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
chunk_size = 32
for chunk_idx in range(0, 100, 32):
......
Args:
chunk_input (List[Node]): chunk input node
chunk_output (Node): chunk output node
chunk_ouput_dim (int): chunk output node chunk dim
chunk_size (int): chunk size. Defaults to 2.
Returns:
context (str): generated str
"""
input_node = chunk_input[0]
out_shape = get_node_shape(chunk_output)
out_str = str(list(out_shape))
context = (
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
(out_str, input_node.name, input_node.name, chunk_size))
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
return context
def _gen_loop_end(
chunk_inputs: List[Node],
chunk_non_compute_inputs: List[Node],
chunk_outputs: Node,
chunk_outputs_dim: int,
node_list: List[Node],
) -> str:
"""
Generate chunk loop end
eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
output_node = chunk_result; xx = None; xx = None
Args:
chunk_inputs (List[Node]): chunk input node
chunk_non_compute_inputs (List[Node]): input node without chunk
chunk_outputs (Node): chunk output node
chunk_outputs_dim (int): chunk output node chunk dim
node_list (List)
Returns:
context (str): generated str
"""
chunk_outputs_name = chunk_outputs.name
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
context = " chunk_result%s = %s; %s = None\n" % (
chunk_slice,
chunk_outputs_name,
chunk_outputs_name,
)
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
# determine if it's the last use of the chunk input
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
context += "; %s = None" % chunk_input.name
context += "\n"
return context
def _replace_name(context: str, name_from: str, name_to: str) -> str:
"""
replace node name
"""
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
for p in patterns:
source = p[0] + name_from + p[1]
target = p[0] + name_to + p[1]
if source in context:
context = context.replace(source, target)
break
return context
def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
"""
replace reshape size, some may have changed due to chunk
"""
if node_name not in reshape_size_dict:
return context
context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
return context
def _replace_ones_like(
search_chunk: SearchChunk,
chunk_infos: List[Dict],
region_idx: int,
node_idx: int,
node: Node,
body: List[str],
) -> List[str]:
"""
add chunk slice for new-tensor ops such as ones_like
"""
if "ones_like" in node.name:
meta_node = search_chunk.trace_indice.node_list[node_idx]
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
def _replace_input_node(
chunk_inputs: List[Node],
region_idx: int,
chunk_inputs_dim: Dict,
node_idx: int,
body: List[str],
) -> List[str]:
"""
add chunk slice for input nodes
"""
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
return body
def emit_code_with_chunk(
body: List[str],
nodes: Iterable[Node],
emit_node_func,
delete_unused_value_func,
search_chunk: SearchChunk,
chunk_infos: List,
):
"""
Emit code with chunk according to chunk_infos.
It will generate a for loop in chunk regions, and
replace inputs and outputs of regions with chunked variables.
Args:
body: forward code
nodes: graph.nodes
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
search_chunk: the class to search all chunks
chunk_infos: store all information about all chunks.
"""
node_list = list(nodes)
# chunk region
chunk_starts = [i["region"][0] for i in chunk_infos]
chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
node_idx = 0
region_idx = 0
within_chunk_region = False
while node_idx < len(node_list):
node = node_list[node_idx]
# if is chunk start, generate for loop start
if node_idx in chunk_starts:
within_chunk_region = True
region_idx = chunk_starts.index(node_idx)
body.append(
_gen_loop_start(
chunk_inputs[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
))
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
# ones like
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# reassign reshape size
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, chunk_inputs_names)
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
node_list,
))
within_chunk_region = False
node_idx += 1
if CODEGEN_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False) -> None:
super().__init__()
# find the chunk regions
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
self.chunk_infos = self.search_chunk.search_region()
if print_progress:
get_logger().info("AutoChunk start codegen")
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
return _get_qualified_name(obj)
# normalize the name hint to get a proper identifier
global_name = namespace.create_name(name_hint, obj)
if global_name in globals_:
assert globals_[global_name] is obj
return global_name
globals_[global_name] = obj
return global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return "()"
typename = _type_repr(o)
if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
if len(args) == 0:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python < 3.9
return origin_typename
return f'{origin_typename}[{",".join(args)}]'
else:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
return origin_typename
# Common case: this is a regular module name like 'foo.bar.baz'
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
delete_free_var_from_last_use(user_to_last_uses)
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body, to_keep=[]):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == "placeholder":
return
if user.op == "output":
body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
if len(nodes_to_delete):
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f"; {to_delete_str}\n")
else:
body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
body.append(f"{repr(node)} = {raw_name}\n")
return
elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})")
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
assert isinstance(node.args, tuple)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
and node.args[1].isidentifier() and len(node.args) == 2):
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
return
body.append(
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk(
body,
nodes,
emit_node,
delete_unused_values,
self.search_chunk,
self.chunk_infos,
)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append("pass\n")
if len(wrapped_fns) > 0:
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
for name, value in self.additional_globals():
add_global(name, value)
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = "".join(ckpt_func) + prologue
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
# print(fn_code)
return PythonCode(fn_code, globals_)
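A hedged end-to-end sketch of driving the codegen. The tracing calls follow the imports and usage shown earlier in this diff; the meta-info propagation step and the `Graph.set_codegen` call are assumptions about the surrounding driver code, and `model` plus the meta shapes are placeholders.

```python
import torch
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer

meta_args = {"x": torch.empty(1, 64, 64, device="meta")}    # shapes are illustrative

# graph used for the chunk search; node meta such as tensor_meta/fwd_out is assumed
# to have been populated beforehand (e.g. by a meta-info propagation pass)
meta_graph = ColoTracer().trace(root=model, meta_args=meta_args)
meta_gm = ColoGraphModule(model, meta_graph)

codegen = AutoChunkCodeGen(meta_gm, max_memory=4000, print_mem=True)

# attach the codegen to a second traced graph and recompile the forward function
graph = ColoTracer().trace(root=model, meta_args=meta_args)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()    # forward() now contains the generated chunk loops
```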

View File

@ -0,0 +1,323 @@
import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
from torch.fx.node import Node, map_arg
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
class EstimateMemory(object):
"""
Estimate memory with chunk
"""
def __init__(self) -> None:
pass
def _get_meta_node_size(self, x):
x = x.meta["tensor_meta"]
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
return x
def _get_output_node(self, n):
out_size = activation_size(n.meta["fwd_out"])
out_node = [n.name] if out_size > 0 else []
return out_size, out_node
def _get_output_node_size(self, n):
return self._get_output_node(n)[0]
def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1]
if n.op == "placeholder" and get_node_shape(n) is not None:
new_active.append(n.name)
for i in new_active:
if i not in active_list and get_node_shape(n) is not None:
active_list.append(i)
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
delete_size = 0
delete_node = []
if user.op not in ("output",):
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
if to_keep is not None:
keep_list = []
for n in nodes_to_delete:
if n.name in to_keep:
keep_list.append(n)
for n in keep_list:
if n in nodes_to_delete:
nodes_to_delete.remove(n)
if len(nodes_to_delete):
out_node = [self._get_output_node(i) for i in nodes_to_delete]
delete_size = sum([i[0] for i in out_node])
for i in range(len(out_node)):
if out_node[i][0] > 0:
delete_node.append(out_node[i][1][0])
elif nodes_to_delete[i].op == "placeholder":
delete_node.append(nodes_to_delete[i].name)
# elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
# delete_node.append(nodes_to_delete[i].name)
return delete_size, delete_node
def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
return self._get_delete_node(user, user_to_last_uses, to_keep)[0]
def _remove_deactive_node(self, user, user_to_last_uses, active_list):
delete_node = self._get_delete_node(user, user_to_last_uses)[1]
for i in delete_node:
if i in active_list:
active_list.remove(i)
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input)
out_node = [self._get_output_node(i) for i in nodes_to_delete]
delete_size = sum([i[0] for i in out_node])
return delete_size
def _get_last_usr(self, nodes):
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
return user_to_last_uses
def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
mem = 0
not_contiguous_ops = ["permute"]
inherit_contiguous_ops = ["transpose", "view"]
if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
for n in node.args:
if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy
mem += self._get_output_node_size(n)
elif node.op == "call_module":
for n in node.args:
if n in not_contiguous_list:
# module will just make origin tensor to contiguous
if delete:
not_contiguous_list.remove(n)
elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
if node not in not_contiguous_list:
not_contiguous_list.append(node)
return mem
def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
if node not in chunk_node_dim:
return 1.0
node_shape = get_node_shape(node)
chunk_dim = chunk_node_dim[node]["chunk_dim"]
if chunk_dim is None:
return 1.0
else:
return float(chunk_size) / node_shape[chunk_dim]
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
# if any(j in user.name for j in ['transpose', 'permute', 'view']):
# return 0
if user.op in ("placeholder", "output"):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
delete_size = 0
for n in nodes_to_delete:
if n.name in chunk_inputs_names:
continue
delete_size += self._get_output_node_size(n) * chunk_ratio
return delete_size
def _print_mem_log(self, log, nodes, title=None):
if title:
print(title)
for idx, (l, n) in enumerate(zip(log, nodes)):
print("%s:%.2f \t" % (n.name, l), end="")
if (idx + 1) % 3 == 0:
print("")
print("\n")
def _print_compute_op_mem_log(self, log, nodes, title=None):
if title:
print(title)
for idx, (l, n) in enumerate(zip(log, nodes)):
if n.op in ["placeholder", "get_attr", "output"]:
continue
if any(i in n.name for i in ["getitem", "getattr"]):
continue
print("%s:%.2f \t" % (n.name, l), end="")
if (idx + 1) % 3 == 0:
print("")
print("\n")
def estimate_chunk_inference_mem(
self,
node_list: List,
chunk_infos=None,
print_mem=False,
):
"""
Estimate inference memory with chunk
Args:
node_list (List): list of graph nodes to estimate
chunk_infos (Dict): Chunk information. Defaults to None.
print_mem (bool): Whether to print the peak memory of every node. Defaults to False.
Returns:
act_memory_peak_log (List): peak memory of every node
act_memory_after_node_log (List): memory after executing every node
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []
active_node_list = []
active_node_list_log = []
not_contiguous_list = []
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_names = []
if use_chunk:
chunk_regions = [i["region"] for i in chunk_infos]
chunk_starts = [i[0] for i in chunk_regions]
chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_list):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
# determine chunk ratio for current node
if chunk_within:
chunk_ratio = self._get_chunk_ratio(
node,
chunk_node_dim[chunk_region_idx],
chunk_sizes[chunk_region_idx],
)
# if node is placeholder, just add the size of the node
if node.op == "placeholder":
act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
act_memory_peak_log.append(act_memory)
# skip output
elif node.op == "output":
continue
# no change for non compute node
elif is_non_memory_node(node):
act_memory_peak_log.append(act_memory)
# node is a compute op
# calculate tmp, output node and delete node memory
else:
# forward memory
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
(1024**2))
# delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends
if chunk_within:
act_memory -= self._get_chunk_delete_node_size(
node,
user_to_last_uses_no_free_var,
chunk_ratio,
chunk_inputs_names,
) / (1024**2)
else:
act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
chunk_inputs_names) / (1024**2)
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
act_memory -= self._get_chunk_inputs_size(
chunk_inputs[chunk_region_idx],
chunk_inputs_non_chunk[chunk_region_idx],
node_list,
chunk_regions[chunk_region_idx][1],
) / (1024**2)
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
act_memory_after_node_log.append(act_memory)
active_node_list_log.append(copy.deepcopy(active_node_list))
if print_mem:
print("with chunk" if use_chunk else "without chunk")
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_compute_op_mem_log(
# act_memory_after_node_log, node_list, "after"
# )
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
def get_active_nodes(self, node_list: List) -> List:
"""
Get active nodes for every node
Args:
node_list (List): list of graph nodes
Returns:
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
active_node_list = []
active_node_list_log = []
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
for _, node in enumerate(node_list):
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
active_node_list_log.append(copy.deepcopy(active_node_list))
return active_node_list_log
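A short hedged sketch of calling the estimator on its own; `meta_gm` is assumed to be a meta-traced graph module whose nodes already carry the `tensor_meta`/`fwd_out` metadata the estimator reads.

```python
estimator = EstimateMemory()
node_list = list(meta_gm.graph.nodes)

peak_log, after_log, active_log = estimator.estimate_chunk_inference_mem(node_list)
print("estimated peak activation memory: %.2f MB" % max(peak_log))

# passing chunk_infos produced by the chunk search re-runs the estimate with chunking:
# estimator.estimate_chunk_inference_mem(node_list, chunk_infos, print_mem=True)
```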

View File

@ -0,0 +1,117 @@
from .trace_indice import TraceIndice
from .utils import find_idx_by_name
class ReorderGraph(object):
"""
Reorder node list and indice trace list
"""
def __init__(self, trace_indice: TraceIndice) -> None:
self.trace_indice = trace_indice
self.all_reorder_map = {
i: i for i in range(len(self.trace_indice.indice_trace_list))
}
def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [
find_idx_by_name(i.name, self.trace_indice.node_list)
for i in chunk_prepose_nodes
]
# put prepose nodes ahead
for idx, n in enumerate(chunk_prepose_nodes):
n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes
for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
if n in chunk_prepose_nodes:
continue
n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos
return reorder_map
def _reorder_chunk_info(self, chunk_info, reorder_map):
# update chunk info
chunk_info["region"] = (
chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
chunk_info["region"][1],
)
new_inputs_dim = []
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
new_input_dim = {}
for k, v in input_dim.items():
new_input_dim[reorder_map[k]] = v
new_inputs_dim.append(new_input_dim)
chunk_info["inputs_dim"] = new_inputs_dim
return chunk_info
def _update_all_reorder_map(self, reorder_map):
for origin_idx, map_idx in self.all_reorder_map.items():
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
self.trace_indice.node_list = new_node_list
def _reorder_idx_trace(self, reorder_map):
# reorder list
new_idx_trace_list = [
None for _ in range(len(self.trace_indice.indice_trace_list))
]
for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
self.trace_indice.indice_trace_list = new_idx_trace_list
# update compute
for idx_trace in self.trace_indice.indice_trace_list:
compute = idx_trace["compute"]
for dim_compute in compute:
for idx, i in enumerate(dim_compute):
dim_compute[idx] = reorder_map[i]
# update source
for idx_trace in self.trace_indice.indice_trace_list:
source = idx_trace["source"]
for dim_idx, dim_source in enumerate(source):
new_dim_source = {}
for k, v in dim_source.items():
new_dim_source[reorder_map[k]] = v
source[dim_idx] = new_dim_source
def reorder_all(self, chunk_info):
if chunk_info is None:
return chunk_info
if len(chunk_info["args"]["prepose_nodes"]) == 0:
return chunk_info
reorder_map = self._get_reorder_map(chunk_info)
self._update_all_reorder_map(reorder_map)
self._reorder_idx_trace(reorder_map)
self._reorder_self_node_list(reorder_map)
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
return chunk_info
def reorder_node_list(self, node_list):
new_node_list = [None for _ in range(len(node_list))]
for old_idx, new_idx in self.all_reorder_map.items():
new_node_list[new_idx] = node_list[old_idx]
return new_node_list
def tmp_reorder(self, node_list, chunk_info):
if len(chunk_info["args"]["prepose_nodes"]) == 0:
return node_list, chunk_info
reorder_map = self._get_reorder_map(chunk_info)
# new tmp node list
new_node_list = [None for _ in range(len(node_list))]
for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = node_list[old_idx]
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
return new_node_list, chunk_info
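A tiny self-contained illustration of how a reorder map is applied, mirroring `reorder_node_list` above; the list and map values are invented. Here the node originally at index 2 is a prepose node moved to index 1, so the node at index 1 shifts back to index 2.

```python
node_list = ["a", "b", "c", "d"]
reorder_map = {0: 0, 1: 2, 2: 1, 3: 3}       # old index -> new index

new_list = [None] * len(node_list)
for old_idx, new_idx in reorder_map.items():
    new_list[new_idx] = node_list[old_idx]

assert new_list == ["a", "c", "b", "d"]
```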

View File

@ -0,0 +1,319 @@
import copy
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
"""
This is the core class for AutoChunk.
It defines the framework of the strategy of AutoChunk.
Chunks will be selected one by one until the search stops.
The chunk search is as follows:
1. find the peak memory node
2. find the max chunk region according to the peak memory node
3. find all possible chunk regions in the max chunk region
4. find the best chunk region for current status
5. goto 1
Attributes:
gm: graph model
print_mem (bool): print estimated memory
trace_indice: trace the indice of every dim of every node to find all free dims
trace_flow: determine the region chunk strategy
reorder_graph: reorder nodes to improve chunk efficiency
estimate_memory: estimate memory with chunk
select_chunk: select the best chunk region
Args:
gm: graph model
max_memory (int): max memory in MB
print_mem (bool): print estimated memory
"""
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.print_progress = print_progress
self.trace_indice = TraceIndice(list(gm.graph.nodes))
self.estimate_memory = EstimateMemory()
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice)
self.reorder_graph = ReorderGraph(self.trace_indice)
self.select_chunk = SelectChunk(
self.trace_indice,
self.estimate_memory,
self.reorder_graph,
max_memory=max_memory,
)
def _init_trace(self) -> None:
"""
find the max trace range for every node
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1]
if cur_node_idx == len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)
# nothing to limit for the first range
max_chunk_region_list = max_chunk_region_list[1:]
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
# set trace range and do the trace
if self.print_progress:
get_logger().info("AutoChunk start tracing indice")
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
self.trace_indice.trace_indice()
def _find_peak_node(self, mem_peak: List) -> int:
max_value = max(mem_peak)
max_idx = mem_peak.index(max_value)
return max_idx
def _get_free_var_idx(self) -> List:
"""
Get free var index
Returns:
free_var_idx (List): all indices of free vars
"""
free_var_idx = []
for idx, n in enumerate(self.trace_indice.node_list):
if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx)
return free_var_idx
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
"""
Search max chunk region according to peak memory node
The chunk region extends outward from the peak node and stops where the number of active nodes exceeds the threshold
Args:
active_node (List): active node status for every node
peak_node_idx (int): peak memory node idx
chunk_regions (List): chunk region infos
Returns:
chunk_region_start (int)
chunk_region_end (int)
"""
free_vars = self._get_free_var_idx()
free_var_num = len(free_vars)
active_node_num = [len(i) for i in active_node]
min_active_node_num = min(active_node_num[free_var_num:])
threshold = max(free_var_num, min_active_node_num)
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
for i in range(peak_node_idx, -1, -1):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_start = i + 1
break
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
for i in range(peak_node_idx, len(active_node)):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_end = i
break
# avoid chunk regions overlap
if chunk_regions is not None:
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop over every dim of the start node and end node to see if we can find a dim pair
which is linked in a flow and not computed.
If found, we then search the flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
if len(start_traces) > 1:
continue
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# must have users
if len(end_node.users) == 0:
continue
# check index source align
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
# check index compute
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
# check index duplicate
if not self.trace_flow.check_index_duplicate(chunk_info):
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
"""
Search every possible region within the max chunk region.
Args:
max_chunk_region (Tuple)
peak_node (Node): peak memory node
Returns:
possible_chunk_region (List)
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
self.trace_indice.node_list[end_idx]):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
def _step_search(
self,
mem_peak: List[float],
active_node: List[List[Node]],
chunk_infos: List[Dict],
) -> Dict:
"""
Find one chunk region
The chunk search is as follows:
1. find the peak memory node
2. find the max chunk region according to the peak memory node
3. find all possible chunk regions in the max chunk region
4. find the best chunk region for current status
Args:
mem_peak (List): peak memory for every node
active_node (List[List[Node]]): active node for every node
chunk_infos (List[Dict]): all chunk info
Returns:
best_chunk_region (Dict)
"""
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
if max_chunk_region is None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
max_chunk_region, mem_peak)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region
def _stop_search(self, init_mem_peak, mem_peak):
sorted_init_mem_peak = sorted(init_mem_peak)
if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
return True
return False
def search_region(self) -> Dict:
"""
Search all chunk regions:
1. Estimate current memory
2. Find best chunk for current memory
3. goto 1
Returns:
chunk_infos (Dict)
"""
if self.print_progress:
get_logger().info("AutoChunk start searching chunk regions")
chunk_infos = []
(
init_mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
mem_peak = init_mem_peak
while True:
chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
if chunk_info is None:
break
chunk_infos.append(chunk_info)
(
mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
return chunk_infos
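
Note on the stop criterion used above: `_stop_search` ends the loop once the highest remaining peak drops below the median of the initial per-node peaks. A tiny standalone sketch with made-up numbers (not taken from any real model) illustrates the check:

# Minimal illustration of the stop criterion in _stop_search; the numbers are invented.
init_mem_peak = [120.0, 90.0, 300.0, 250.0, 80.0, 60.0]    # per-node peak memory (MB) before chunking
mem_peak = [110.0, 90.0, 130.0, 120.0, 80.0, 60.0]         # per-node peak memory after the chunks found so far

sorted_init = sorted(init_mem_peak)                        # [60, 80, 90, 120, 250, 300]
median_init = sorted_init[int(len(sorted_init) * 0.5)]     # 120.0

# the search stops once the current maximum peak falls below that median
stop = max(mem_peak) < median_init                         # 130.0 < 120.0 -> False, keep searching
print(stop)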

View File

@ -0,0 +1,224 @@
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_indice: TraceIndice,
estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph,
max_memory=None,
):
self.trace_indice = trace_indice
self.estimate_memory = estimate_memory
self.reorder_graph = reorder_graph
if max_memory is not None:
self.stratge = "fit_memory"
self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
def _select_best_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
if self.stratge == "min_memory":
best_region = self._select_min_memory_chunk_region(
possible_chunk_regions,
chunk_infos,
peak_node,
max_chunk_region,
mem_peak,
)
elif self.stratge == "fit_memory":
best_region = self._select_fit_memory_chunk_region(
possible_chunk_regions,
chunk_infos,
peak_node,
max_chunk_region,
mem_peak,
)
else:
raise RuntimeError()
return best_region
def _select_fit_memory_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
# stop chunking if the peak memory already satisfies the memory limit
if max(mem_peak) < self.max_memory:
return None
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions:
if not self._is_legal_region(i, chunk_infos):
illegal_regions.append(i)
for i in illegal_regions:
if i in possible_chunk_regions:
possible_chunk_regions.remove(i)
if len(possible_chunk_regions) == 0:
return None
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.trace_indice.node_list, cur_region
)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[
max_chunk_region[0] : max_chunk_region[1] + 1
]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(
region["region"][0], region["region"][1]
),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
# select the min chunk len
chunk_len = [i["chunk_len"] for i in regions_dict]
best_region_idx = chunk_len.index(min(chunk_len))
best_region = regions_dict[best_region_idx]
# get max chunk size
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
return best_region
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
chunk_size = 1
reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_max_mem = 0
# double the chunk size until the memory limit is exceeded
while cur_chunk_max_mem < self.max_memory:
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
cur_mem_peak[
reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
+ 1
]
)
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
)
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
if left >= 16:
gap = 4
else:
gap = 1
chunk_info = chunk_region_dict["reorder_chunk_info"]
while right >= left + gap:
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
)
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
left = mid + gap
return left
def _get_compute_node_num(self, start, end):
count = 0
for i in self.trace_indice.node_list[start : end + 1]:
if not is_non_compute_node(i):
count += 1
return count
def _select_min_memory_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions:
if not self._is_legal_region(i, chunk_infos):
illegal_regions.append(i)
for i in illegal_regions:
if i in possible_chunk_regions:
possible_chunk_regions.remove(i)
if len(possible_chunk_regions) == 0:
return None
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.trace_indice.node_list, cur_region
)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[
max_chunk_region[0] : max_chunk_region[1] + 1
]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(
region["region"][0], region["region"][1]
),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
best_region = regions_dict[best_region_idx]["chunk_info"]
if best_region is not None:
best_region["chunk_size"] = 1
return best_region
def _is_legal_region(self, cur_chunk_info, chunk_infos):
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
if cur_chunk_info in chunk_infos:
return False
if chunk_region_end < chunk_region_start:
return False
for i in chunk_infos:
region = i["region"]
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False
return True
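
The chunk-size search above combines a doubling phase (`_get_fit_chunk_size`) with a gapped binary search (`_chunk_size_binary_search`). The sketch below reproduces that pattern with a toy linear memory model standing in for `estimate_chunk_inference_mem`; the model and the numbers are invented for illustration only.

# Toy sketch of the fit-memory chunk-size search; `toy_mem` is a made-up stand-in
# for the real memory estimator and is not part of this PR.
def toy_mem(chunk_size: int) -> float:
    return 100.0 + 2.5 * chunk_size                  # pretend peak memory grows with chunk size

def find_fit_chunk_size(max_memory: float) -> int:
    chunk_size = 1
    while toy_mem(chunk_size) < max_memory:          # doubling phase, as in _get_fit_chunk_size
        chunk_size *= 2
    left, right = chunk_size // 2, chunk_size        # gapped binary search, as in _chunk_size_binary_search
    gap = 4 if left >= 16 else 1
    while right >= left + gap:
        mid = int((left + right) / 2 + 0.5)
        if toy_mem(mid) >= max_memory:
            right = mid - gap
        else:
            left = mid + gap
    return left

print(find_fit_chunk_size(max_memory=200.0))         # -> 42 with this toy model (within the search gap of the true limit)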

View File

@ -0,0 +1,445 @@
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .trace_indice import TraceIndice
from .utils import (
find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes,
find_idx_by_name,
flat_list,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice) -> None:
self.trace_indice = trace_indice
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
"""
Check two given indices: one index should be the source of the other.
Args:
start_dim (int): start node chunk dim
start_node (Node): start node
start_idx (int): start node index
end_dim (int): end node chunk dim
end_node (Node): end node
Returns:
bool: True if the check passes
"""
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
for node_idx, node_dim in sorted_source:
if node_idx == start_node_idx and start_dim in node_dim:
return True
# the source comes from a node before the chunk region that is not the start node
if node_idx < start_idx:
return False
return False
def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
"""
Check the given indices: they must not have been computed between start_idx and end_idx.
Args:
start_idx (int): start node index
end_dim (int): end node chunk dim
end_node (Node): end node
end_idx (int): end node index
Returns:
bool: True if the check passes
"""
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_compute = end_node_trace["compute"][end_dim]
if any(start_idx <= i <= end_idx for i in end_node_compute):
return False
return True
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
dim_source = node_from_source[node_from_dim]
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
for k, v in dim_source.items():
if k == node_to_idx:
return v
return None
def _find_inherit_dim(self, input_node, input_dim, node):
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
if (input_node_idx in node_trace_source[node_dim]
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
return node_dim
return None
def check_index_duplicate(self, chunk_infos, return_dim=False):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
if inherit_dim is not None:
input_dim_after_node[k] = inherit_dim
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
if is_non_compute_node_except_placeholder(node):
continue
count = 0
duplicate_dims = []
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
duplicate_dim = []
duplicate_flag = False
dim_source = node_trace_source[node_dim]
for k, v in dim_source.items():
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
if k in input_dim_after_node and input_dim_after_node[k] in v:
duplicate_flag = True
duplicate_dim.append((k, v))
duplicate_dims.append(duplicate_dim)
if duplicate_flag:
count += 1
if count > 1:
if return_dim:
return False, duplicate_dims
else:
return False
if return_dim:
return True, None
else:
return True
def _assgin_single_node_flow(
self,
arg_node: Node,
start_idx: int,
end_idx: int,
cur_node_dim: int,
cur_node_compute: Dict,
cur_node_source: Dict,
cur_node_fix_dim: List,
all_node_info: Dict,
next_node_list: List,
) -> bool:
"""
Given the current node and one of its arg node,
this function finds out arg node's chunk dim and fix dim
Args:
arg_node (Node): input node
start_idx (int): chunk region start
end_idx (int): chunk region end
cur_node_dim (int): current node chunk dim
cur_node_compute (Dict): current node compute dict
cur_node_source (Dict): current node source dict
cur_node_fix_dim (List): current node fix dim
all_node_info (Dict): all node chunk info in the chunk region
next_node_list (List)
Returns:
bool: True if this node can be added to the flow, False otherwise.
"""
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
# args outside the chunk region are treated as inputs and skipped
if not (start_idx <= arg_idx < end_idx):
return True
# find arg dim
if cur_node_dim is not None:
# dim is computed
if arg_idx in cur_node_compute[cur_node_dim]:
return False
if arg_idx not in cur_node_source[cur_node_dim]:
arg_dim = None
else:
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
# chunk dim should be None if shape size is 1
if get_node_shape(arg_node)[arg_dim] == 1:
arg_dim = None
else:
arg_dim = None
# get fix dim
arg_fix_dim = []
if cur_node_dim is not None:
for i in cur_node_fix_dim:
fix_dim_source = cur_node_source[i]
if arg_idx in fix_dim_source:
arg_fix_dim.append(fix_dim_source[arg_idx][0])
# if already in node_info, arg dim must be same
if arg_node in all_node_info:
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
return False
all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
# else add it to list
else:
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
next_node_list.append(arg_node)
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
next_node_list = []
for cur_node in cur_node_list:
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim is not None:
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
else:
cur_node_compute = cur_node_source = None
# get all valid args
arg_list = []
for arg in cur_node.all_input_nodes:
if type(arg) != type(cur_node):
continue
if is_non_compute_node(arg):
continue
arg_list.append(arg)
flow_flag = self._assgin_single_node_flow(
arg,
start_idx,
end_idx,
cur_node_chunk_dim,
cur_node_compute,
cur_node_source,
cur_node_fix_dim,
all_node_info,
next_node_list,
)
if not flow_flag:
return None
if len(arg_list) == 2:
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
for arg in arg_list:
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"]
arg_shape = get_node_shape(arg)
# add all dims except the chunk dim as fix dims
for i, shape in enumerate(arg_shape):
if shape != 1 and i != cur_node_chunk_dim:
if i == arg_chunk_dim:
return None
if i not in arg_fix_dim:
arg_fix_dim.append(i)
elif "einsum" in cur_node.name:
pass
elif "matmul" in cur_node.name:
pass
else:
raise NotImplementedError()
cur_node_list = next_node_list
return all_node_info
def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
"""
Get the chunk dim of every input node for each of its users, and remove inputs that are not chunked
Args:
inputs (List[Node]): input nodes
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
inputs (List[Node]): new inputs
inputs_dim (List): chunk dim for inputs
"""
inputs_dim = []
remove_inputs = []
for input_node in inputs:
input_dict = {}
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
for user in input_node.users.keys():
# skip non compute
if is_non_compute_node(user):
continue
# untraced node, mostly non compute
if user not in all_node_info:
continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
if input_node_idx in user_source:
if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
input_dict[user_idx] = [None]
else:
input_dict[user_idx] = user_source[input_node_idx]
else:
return None, None
if len(input_dict) == 0:
remove_inputs.append(input_node)
else:
inputs_dim.append(input_dict)
# remove unchunked inputs
for i in remove_inputs:
if i in inputs:
inputs.remove(i)
return inputs, inputs_dim
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
"""
Get all nodes in the chunk region that do not depend on the chunk dim, so they can be moved ahead of the chunk loop
Args:
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
List[Node]: all nodes to be preposed
"""
# get all possible prepose nodes
maybe_prepose_nodes = []
for node, node_info in all_node_info.items():
if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
reverse=True,
) # from last node to first node
prepose_nodes = []
# treat every node as a root and search its args; if all of them are legal, mark the root and its args as prepose nodes
while len(maybe_prepose_nodes) > 0:
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
tmp_cur_related_prepose_nodes = []
prepose_flag = True
# traverse all args of the current node until we leave the chunk region
while len(tmp_cur_prepose_nodes) > 0:
if not prepose_flag:
break
tmp_next_prepose_nodes = []
tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
for cur_prepose_node in tmp_cur_prepose_nodes:
if not prepose_flag:
break
for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
end_idx):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
else:
prepose_flag = False
break
# non compute op
else:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
if not prepose_flag:
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
continue
else:
for n in tmp_cur_related_prepose_nodes:
if n not in prepose_nodes:
prepose_nodes.append(n)
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
return prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
# prepose nodes are removed first so that their args are not pulled into non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
return chunk_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
# only a single output is supported
if len(outputs) > 1:
return None
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info = {
"region": (start_idx, end_idx),
"inputs": inputs,
"inputs_non_chunk": [],
"inputs_dim": inputs_dim,
"outputs": outputs,
"outputs_dim": end_dim,
"node_chunk_dim": all_node_info,
"args": {},
}
# move nodes that do not depend on the chunk dim ahead of the loop
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
# reassign reshape sizes, since some sizes may have changed due to chunking
chunk_info = self._reassgin_reshape_size(chunk_info)
return chunk_info
def _reassgin_reshape_size(self, chunk_info):
"""
Some shape args of reshape ops may have changed due to chunking;
reassign those changed shapes.
"""
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
reshape_args = flat_list(node.args[1:])
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
if reshape_arg_dim == chunk_dim:
new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
else:
if isinstance(reshape_arg, int):
new_shape += "%s, " % str(reshape_arg)
else:
new_shape += "%s, " % reshape_arg.name
new_shape = new_shape[:-2]
origin_shape = str(reshape_args)[1:-1]
reshape_size[node.name] = [origin_shape, new_shape]
chunk_info["reshape_size"] = reshape_size
return chunk_info
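
For orientation, this is roughly what a chunk_info returned by flow_search looks like once `_reassgin_reshape_size` has run. All values below are invented, and the entries that are torch.fx Node objects in the real code are shown as plain strings here.

# Hypothetical chunk_info layout (illustrative only; "inputs", "outputs" and the
# node_chunk_dim keys are torch.fx Node objects in the real code, not strings).
example_chunk_info = {
    "region": (12, 47),                              # (start_idx, end_idx) in the traced node list
    "inputs": ["msa", "pair_bias"],                  # chunked input nodes
    "inputs_non_chunk": ["mask"],                    # inputs used as-is inside the chunk loop
    "inputs_dim": [{15: [1]}, {20: [1]}],            # per input: {user node idx: source dims}
    "outputs": ["attn_out"],                         # the single output node of the region
    "outputs_dim": 1,                                # chunk dim of the output
    "node_chunk_dim": {"attn_out": {"chunk_dim": 1, "fix_dim": [3]}},
    "args": {"prepose_nodes": []},                   # nodes moved ahead of the chunk loop
    "reshape_size": {                                # filled in by _reassgin_reshape_size
        "reshape_3": ["128, 32, 64", "min(chunk_size, 128 - chunk_idx), 32, 64"],
    },
}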

View File

@ -0,0 +1,703 @@
import copy
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
class TraceIndice(object):
"""
Trace all indice information for every node.
Indice is a logical concept: equal dims can be treated as one indice.
e.g. dim(x1) = [a, b, c]
dim(x2) = [d, e, f]
and we have x3 = x1 * x2,
then a=d, b=e, c=f due to the broadcast property, so
dim(x1) = dim(x2) = dim(x3) = [a, b, c].
This class records every node's dims' indice, compute and source.
Attributes:
node_list (List)
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
indice_view_list (Dict): not used for now
indice_count (int): record indice number
Args:
node_list (List)
"""
def __init__(self, node_list: List[Node]) -> None:
self.node_list = node_list
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self):
indice_trace_list = []
for n in self.node_list:
if get_node_shape(n) is not None:
cur_trace = {
"indice": [None for _ in range(len(get_node_shape(n)))],
"compute": [[] for _ in range(len(get_node_shape(n)))],
"source": [{} for _ in range(len(get_node_shape(n)))],
}
else:
cur_trace = {"indice": [], "compute": [], "source": []}
indice_trace_list.append(cur_trace)
return indice_trace_list
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
self.trace_range = trace_range
self.active_node_list = active_node_list
def _add_indice(self):
"""
Increment the indice count and return it, allocating a new indice id.
Returns:
indice_count: int
"""
self.indice_count += 1
return self.indice_count
def _del_dim(self, idx, dim_idx):
self.indice_trace_list[idx]["indice"].pop(dim_idx)
self.indice_trace_list[idx]["compute"].pop(dim_idx)
self.indice_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, node_idx, dim_idx):
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
def _transform_indice(self, node, node_dim):
node_idx = self._find_indice_trace_from_node(node)
dims = list(range(len(node_idx)))
return dims[node_dim]
def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
def _inherit_all_computation(self, node_from, node_to):
node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to)
assert len(node_from_compute) == len(node_to_compute)
for i in range(len(node_from_compute)):
self._add_source(node_from, i, node_to, i)
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = find_idx_by_name(node_from.name, self.node_list)
if init:
node_to_trace_source[node_to_dim] = {}
# add dim to cur new source
if node_from_idx not in node_to_trace_source[node_to_dim]:
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
else:
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim)
# update inputs source
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
if node_idx not in node_to_trace_source[node_to_dim]:
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
else:
for d in node_dim:
if d not in node_to_trace_source[node_to_dim][node_idx]:
node_to_trace_source[node_to_dim][node_idx].append(d)
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
if exclude is None:
exclude = []
else:
exclude = [self._transform_indice(node_to, i) for i in exclude]
node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to)
# assert len(node_from_compute) == len(node_to_compute)
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
if self._transform_indice(node_to, i) in exclude:
continue
self._add_source(node_from, i, node_to, i)
for j in node_from_compute[i]:
if j not in node_to_compute[i]:
node_to_compute[i].append(j)
def _mark_computation(self, node, idx, dim):
"""
Mark some dims of node as computed.
Args:
node (node)
idx (int): node index
dim (list or int): dims to be marked as computed
"""
if isinstance(dim, int):
dim = [dim]
dims = list(range(len(get_node_shape(node))))
for d in dim:
cur_dim = dims[d]
if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
def _find_trace_from_node(self, node):
"""
Find the full indice trace dict of a node.
Args:
node (node)
Returns:
node_dict (Dict): trace dict with keys "indice", "compute" and "source".
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_dict = self.indice_trace_list[node_idx]
return node_dict
def _find_source_trace_from_node(self, node):
"""
Find node source trace by the node.
Args:
node (node)
Returns:
source (List[Dict]): source trace for every dim of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_dict = self.indice_trace_list[node_idx]
return node_dict["source"]
def _find_indice_trace_from_node(self, node):
"""
Find node idx trace by the node.
Args:
node (node)
Returns:
indice (List): indice trace of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node):
"""
Find node compute trace by the node.
Args:
node (node)
Returns:
compute (list): computed idx of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
"""
Assign the node's indice trace from its input node.
Args:
node (node)
node_idx (int)
"""
if input_node is None:
input_node = find_first_tensor_arg(node)
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
self._inherit_all_computation(input_node, node)
def _assign_all_indice(self, node: Node, node_idx: int):
"""
Add new indice for all node's dims.
Args:
node (node)
node_idx (int)
"""
shape = node.meta["tensor_meta"].shape
if shape is None:
return
new_trace = []
for _ in shape:
new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node: Node, node_idx: int):
"""
Assign indice for transpose op.
1. swap input's dim according to transpose args
2. inherit input's computation
Args:
node (node)
node_idx (int)
"""
input_node = node.args[0]
tranpose_dim = node.args[1:]
self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
def _assign_permute_indice(self, node: Node, node_idx: int):
"""
Assign indice for permute op.
1. swap input's dim according to permute args
2. inherit input's computation
Args:
node (node)
node_idx (int)
"""
permute_dim = flat_list(node.args[1:])
input_node = node.args[0]
self._assign_indice_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx)
def _assign_linear_indice(self, node: Node, node_idx: int):
"""
Assign indice for linear op.
1. copy trace from the input node and change the last indice according to the weight
2. mark equal for input node last indice, weight first dim and bias dim.
3. inherit input's computation, mark computation for last dim.
Args:
node (node)
node_idx (int)
"""
if len(node.args) == 2:
_, weight = node.args
else:
_, weight, _ = node.args
self._assign_indice_as_input(node, node_idx)
self._inherit_indice(weight, 1, node, -1)
self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node: Node, node_idx: int):
"""
Assign indice for matmul op.
1. copy trace from matmul_left and change the last indice according to matmul_right (assert they have the same length).
2. mark equal for input matmul_left -1 indice and matmul_right -2 dim.
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
Args:
node (node)
node_idx (int)
"""
matmul_left, matmul_right = node.args
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_indice_as_input(node, node_idx, matmul_left)
self._inherit_indice(matmul_right, -1, node, -1)
self._mark_computation_from_node(matmul_right, node, [-1, -2])
self._mark_computation(node, node_idx, [-1])
def _assign_layernorm_indice(self, node, idx):
"""
Assign indice for layernorm op.
1. assign indice as input node
2. inherit computation and mark the last dim as computed.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [-1])
def _assign_elementwise_indice(self, node, idx):
"""
Assign indice for element-wise ops (e.g. relu, sigmoid, add, mul).
1. assign indice as input node
2. inherit computation from all input nodes.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, idx)
nodes_in = []
for node_in in node.args:
if type(node_in) == type(node):
nodes_in.append(node_in)
self._mark_computation_from_node(node_in, node)
assert len(nodes_in) <= 2
def _assgin_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
self._mark_computation_from_node(node_in, node)
def _assign_einsum_indice(self, node, idx):
"""
Assign indice for einsum op.
Args:
node (node)
node_idx (int)
"""
patterns = node.args[0]
input_nodes = node.args[1:]
patterns = patterns.replace(" ", "")
left, right = patterns.split("->")
left = left.split(",")
if '...' in right:
replace_list = "!@#$%^&*"
target_len = len(get_node_shape(node))
add_len = target_len - len(right) + 3
replace_str = replace_list[:add_len]
right = right.replace("...", replace_str)
for ll in range(len(left)):
left[ll] = left[ll].replace("...", replace_str)
all_index = []
for i in left:
for c in i:
all_index.append(c)
all_index = set(all_index)
for right_idx, right_indice in enumerate(right):
for left_idx, left_str in enumerate(left):
if right_indice in left_str:
source_idx = left_str.index(right_indice)
self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)
def _assign_softmax_indice(self, node, idx):
"""
Assign indice for softmax op.
1. assign indice as input node
2. inherit computation and mark softmax dim as computed.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx)
dim_idx = node.args[1]
# convert a negative dim to a positive one w.r.t. the output shape, e.g. unsqueeze(-1) inserts at the last output dim
if dim_idx < 0:
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_dropout_indice(self, node: Node, node_idx: int):
"""
Assign indice for dropout op.
Dropout keeps the input shape, so simply inherit the input node's indice.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node: Node, node_idx: int):
"""
Assign indice for ones_like op.
1. assign new indice for all dims
Args:
node (node)
node_idx (int)
"""
self._assign_all_indice(node, node_idx)
def _assign_cat_indice(self, node: Node, node_idx: int):
"""
Assign indice for cat op.
Args:
node (node)
node_idx (int)
"""
nodes_in = flat_list(node.args[0])
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._mark_computation_from_node(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
self._add_dim(node_idx, cat_dim)
def _assign_sum_indice(self, node: Node, node_idx: int):
"""
Assign indice for sum op.
Args:
node (node)
node_idx (int)
"""
nodes_in = flat_list(node.args[0])
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._mark_computation_from_node(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
def _assign_getitem_indice(self, node: Node, node_idx: int):
"""
Assign indice for getitem.
getitem can act like slice sometimes
Args:
node (node)
node_idx (int)
"""
node_args = flat_list(node.args[1:])
flag = False
for node_arg in node_args:
node_arg_str = str(node_arg)
if any(i == node_arg_str for i in ["None", "Ellipsis"]):
flag = True
break
if "slice" in node_arg_str:
flag = True
break
if not flag:
return
# node args should be like [Ellipsis, slice(start, stop, step), None]
node_shape = get_node_shape(node)
origin_idx_count = 0
new_idx_count = 0
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
for _ in range(delete_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
for _, node_arg in enumerate(node_args):
node_arg_str = str(node_arg)
# Ellipsis means [..., ]
if "Ellipsis" == node_arg_str:
shape_gap = len(node_shape) - len(node_args) + 1
origin_idx_count += shape_gap
new_idx_count += shape_gap
# slice(None, None, None) means all indexes
elif "slice" in node_arg_str:
if "slice(None, None, None)" != node_arg_str:
self._del_dim(node_idx, new_idx_count)
self._add_dim(node_idx, new_idx_count)
origin_idx_count += 1
new_idx_count += 1
# None means a new dim
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
elif "0" == node_arg_str:
self._del_dim(node_idx, new_idx_count)
origin_idx_count += 1
else:
raise NotImplementedError()
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
"""
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
2. compute the real value of -1 in target shape.
3. determine the changed dims, and assign indice for the generated dims.
4. log changed dim and generated dim for restore
5. inherit computation.
6. TODO: look into view list to see whether the view is associated with other,
if so, assign equal dims according to the previous view.
Args:
node (node)
node_idx (int)
"""
# get data, turn into number
origin_node = node.args[0]
origin_shape = origin_node.meta["tensor_meta"].shape
target_shape = []
unflated_args = flat_list(node.args)
for i in range(1, len(unflated_args)):
if isinstance(unflated_args[i], int):
target_shape.append(unflated_args[i])
else:
target_shape.append(unflated_args[i].meta["fwd_out"][0])
# compute the value of -1
if -1 in target_shape:
origin_product = 1
for i in origin_shape:
origin_product *= i
target_product = -1
for i in target_shape:
target_product *= i
shape_idx = target_shape.index(-1)
target_shape[shape_idx] = origin_product // target_product
# determine changed dim
len_diff = len(origin_shape) - len(target_shape)
if len_diff == 1:
# dim merge
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
dim_to = [dim_equal.index(False)]
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
self._add_dim(node_idx, -1)
elif len_diff == -1:
# dim expand
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
dim_from = [dim_equal.index(False)]
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
self._del_dim(node_idx, -1)
else:
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_indice_as_input(node, node_idx, origin_node)
dim_from.reverse()
for i in dim_from:
self._del_dim(node_idx, i)
for i in dim_to:
self._add_dim(node_idx, i)
# inherit computation
compute_log = self._find_compute_trace_from_node(origin_node)
for i in dim_from:
if origin_trace[i] in compute_log:
for j in dim_to:
self._mark_computation(node, node_idx, [j])
break
# log view, not used now
view_dict = {
"idx_from": [origin_trace[i] for i in dim_from],
"dim_from": dim_from,
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
"dim_to": dim_to,
}
self.indice_view_list[node] = view_dict
def _clear_trace(self, node_idx: int) -> None:
"""
Clear traces that are far away from the current node to speed up computation.
"""
trace_range = None
for i in range(len(self.trace_range)):
if self.trace_range[i][1] == node_idx:
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
break
if self.trace_range[i][1] > node_idx:
break
if trace_range is None:
return
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
def trace_indice(self):
for idx, node in enumerate(self.node_list):
if node.op == "placeholder":
self._assign_all_indice(node, idx)
elif node.op == "call_method":
if "transpose" in node.name:
self._assign_transpose_indice(node, idx)
elif "permute" in node.name:
self._assign_permute_indice(node, idx)
elif "view" in node.name or "reshape" in node.name:
self._assign_view_reshape_indice(node, idx)
elif "unsqueeze" in node.name:
self._assign_unsqueeze_indice(node, idx)
elif any(i in node.name for i in ["to", "contiguous", "clone"]):
self._assgin_no_change_indice(node, idx)
elif "new_ones" in node.name:
self._assign_ones_like_indice(node, idx)
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == "call_function":
if "linear" in node.name:
self._assign_linear_indice(node, idx)
elif "cat" in node.name:
self._assign_cat_indice(node, idx)
elif "matmul" in node.name:
self._assign_matmul_indice(node, idx)
elif "softmax" in node.name:
self._assign_softmax_indice(node, idx)
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" in node.name:
self._assign_ones_like_indice(node, idx)
elif "dropout" in node.name:
self._assign_dropout_indice(node, idx)
elif "einsum" in node.name:
self._assign_einsum_indice(node, idx)
elif "sum" in node.name:
self._assign_sum_indice(node, idx)
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
elif "getitem" in node.name:
self._assign_getitem_indice(node, idx)
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
continue
else:
raise NotImplementedError(node.name, "function not implemented yet!")
elif node.op == "call_module":
if any(n in node.name for n in ["layernorm", "norm"]):
self._assign_layernorm_indice(node, idx)
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node.name, "module not implemented yet!")
elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")
# limit trace range
self._clear_trace(idx)
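
As a concrete illustration of the einsum handling in `_assign_einsum_indice`, the standalone snippet below shows how an output dim is traced back to the operand and dim it inherits its indice from. The pattern is made up and no torch.fx nodes are involved.

# Standalone illustration of the einsum indice mapping (made-up pattern, no fx nodes).
pattern = "bik,bkj->bij"
left, right = pattern.replace(" ", "").split("->")
left = left.split(",")

for right_idx, right_indice in enumerate(right):
    for left_idx, left_str in enumerate(left):
        if right_indice in left_str:
            source_idx = left_str.index(right_indice)
            print(f"output dim {right_idx} ('{right_indice}') inherits from input {left_idx}, dim {source_idx}")
# output dim 0 ('b') inherits from input 0, dim 0  (and from input 1, dim 0)
# output dim 1 ('i') inherits from input 0, dim 1
# output dim 2 ('j') inherits from input 1, dim 2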

View File

@ -0,0 +1,132 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def get_logger():
return logger
def flat_list(inputs: Any) -> List:
"""
Flatten a nested list recursively.
"""
if not isinstance(inputs, (list, set, tuple)):
return [inputs]
res = []
for i in inputs:
if isinstance(i, (list, set, tuple)):
res.extend(flat_list(i))
else:
res.append(i)
return res
def find_first_tensor_arg(node: Node) -> Node:
"""
Find the first input tensor arg for a node
"""
for arg in node.args:
if type(arg) == type(node):
return arg
raise RuntimeError()
def is_non_compute_node(node: Node) -> bool:
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
return True
if "getitem" in node.name:
node_args = flat_list(node.args[1:])
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
return False
if "slice" in str(node_arg):
return False
return True
return False
def get_node_shape(node: Node) -> List:
if hasattr(node.meta["tensor_meta"], "shape"):
return node.meta["tensor_meta"].shape
return None
def is_non_memory_node(node: Node) -> bool:
if "getitem" in node.name:
return True
if "output" in node.op:
return True
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder(node):
if "placeholder" in node.op:
return False
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder_output(node):
if "output" in node.op:
return False
return is_non_compute_node_except_placeholder(node)
def find_idx_by_name(name, nodes_list):
for idx, node in enumerate(nodes_list):
if node.name == name:
return idx
raise RuntimeError("name %s not found in node list" % name)
def delete_free_var_from_last_use(user_to_last_uses):
for key, value in user_to_last_uses.items():
for n in value:
if n.op == "placeholder":
user_to_last_uses[key].remove(n)
def find_chunk_all_input_nodes(nodes: List[Node]):
"""
Find all input nodes of the given node list.
Input nodes are nodes outside the list whose outputs are used by nodes inside the list.
"""
input_nodes = []
for node in nodes:
for input_node in node._input_nodes.keys():
if input_node not in nodes and input_node not in input_nodes:
input_nodes.append(input_node)
return input_nodes
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
"""
Find the compute input and output nodes of the given node list.
Input nodes are nodes outside the list that are used by nodes inside the list;
output nodes are nodes inside the list whose outputs are used outside the list.
"""
input_nodes = []
output_nodes = []
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
if (input_node not in nodes and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
if (output_node not in nodes and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)):
output_nodes.append(node)
return input_nodes, output_nodes
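
A quick sanity check of two of the helpers above, assuming `flat_list` and `find_idx_by_name` are in scope; the node list is faked with a tiny stand-in class instead of real torch.fx nodes.

# Illustration only: assumes flat_list / find_idx_by_name above are in scope; FakeNode replaces torch.fx Node.
from dataclasses import dataclass

@dataclass
class FakeNode:
    name: str

print(flat_list([1, [2, (3, 4)], {5}]))              # -> [1, 2, 3, 4, 5]

nodes = [FakeNode("input_1"), FakeNode("linear"), FakeNode("output")]
print(find_idx_by_name("linear", nodes))             # -> 1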

View File

@ -1,5 +1,5 @@
from typing import List
import socket
from typing import List
class HostInfo:
@ -35,9 +35,14 @@ class HostInfo:
if port is None:
port = 22 # no port specified, lets just use the ssh port
hostname = socket.getfqdn(hostname)
# socket.getfqdn("127.0.0.1") does not return localhost
# on some users' machines
# thus, we directly return True if the hostname is localhost, 127.0.0.1 or 0.0.0.0
if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
return True
hostname = socket.getfqdn(hostname)
localhost = socket.gethostname()
localaddrs = socket.getaddrinfo(localhost, port)
targetaddrs = socket.getaddrinfo(hostname, port)

View File

@ -1,8 +1,10 @@
import fabric
from .hostinfo import HostInfo, HostInfoList
from multiprocessing import Pipe, Process
from multiprocessing import connection as mp_connection
import click
import fabric
from .hostinfo import HostInfo, HostInfoList
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
@ -45,8 +47,10 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
# execute on the remote machine
fab_conn.run(cmds, hide=False)
send_conn.send('success')
except:
click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
except Exception as e:
click.echo(
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
)
send_conn.send('failure')
# shutdown

View File

@ -1,13 +1,16 @@
import click
import sys
import os
import torch
from colossalai.context import Config
from .multinode_runner import MultiNodeRunner
from .hostinfo import HostInfo, HostInfoList
import sys
from typing import List
import click
import torch
from packaging import version
from colossalai.context import Config
from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
NODE_SEP = ','
@ -15,7 +18,7 @@ NODE_SEP = ','
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
"""
Parse the hostfile to obtain a list of hosts.
A hostfile should look like:
worker-0
worker-1
@ -63,7 +66,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
device_pool (HostInfoList): a list of HostInfo objects
include_str (str): --include option passed by user, default None
exclude_str (str): --exclude option passed by user, default None
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
'''
@ -192,7 +195,7 @@ def launch_multi_processes(args: Config) -> None:
Launch multiple processes on a single node or multiple nodes.
The overall logic can be summarized as the pseudo code below:
if hostfile given:
hostinfo = parse_hostfile(hostfile)
hostinfo = include_or_exclude_hosts(hostinfo)
@ -202,7 +205,7 @@ def launch_multi_processes(args: Config) -> None:
launch_on_multi_nodes(hostinfo)
else:
launch_on_current_node()
Args:
args (Config): the arguments taken from command line
@ -276,6 +279,33 @@ def launch_multi_processes(args: Config) -> None:
extra_launch_args=args.extra_launch_args)
runner.send(hostinfo=hostinfo, cmd=cmd)
runner.recv_from_all()
# start training
msg_from_node = runner.recv_from_all()
has_error = False
# print node status
click.echo("\n====== Training on All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
# check if a process failed
if msg == "failure":
has_error = True
# stop all nodes
runner.stop_all()
runner.recv_from_all()
# receive the stop status
msg_from_node = runner.recv_from_all()
# print node status
click.echo("\n====== Stopping All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
# give the process an exit code
# so that it behaves like a normal process
if has_error:
sys.exit(1)
else:
sys.exit(0)

View File

@ -381,6 +381,8 @@ class AlphaBetaProfiler:
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
# The beta values have been enlarged by 1e10 times temporarily because the computation cost
# is still estimated in the unit of TFLOPs instead of time. We will remove this factor in the future.
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
return mesh_alpha, mesh_beta

View File

@ -1,5 +1,6 @@
import operator
from functools import reduce
from typing import List, Tuple
import torch
import torch.distributed as dist
@ -15,7 +16,8 @@ class DeviceMesh:
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
mesh_shape (torch.Size): shape of logical view.
logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
mesh_shape (torch.Size, optional): shape of logical view.
mesh_alpha (List[float], optional): coefficients used for computing
communication cost (default: None)
mesh_beta (List[float], optional): coefficients used for computing
@ -28,15 +30,21 @@ class DeviceMesh:
"""
def __init__(self,
physical_mesh_id,
mesh_shape,
mesh_alpha=None,
mesh_beta=None,
init_process_group=False,
need_flatten=True):
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
need_flatten: bool = True):
self.physical_mesh_id = physical_mesh_id
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
if logical_mesh_id is None:
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape
# map global rank into logical rank
self.convert_map = {}
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
@ -54,8 +62,8 @@ class DeviceMesh:
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish it from the original flatten method (because it is unclear whether any functions still rely on self.flatten())
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
self.mesh_beta)
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
# self.mesh_beta)
@property
def shape(self):

View File

@ -1,24 +1,34 @@
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, Union
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
from typing import Type, Dict, List, Any, Union, Optional, Set
from pathlib import Path
try:
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
COLOGM = True
except:
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
def __init__(self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
class_name: str = 'GraphModule',
ckpt_codegen: bool = True):
if ckpt_codegen:
graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):

View File

@ -9,6 +9,40 @@ def pipe_split():
pass
def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In avgcompute_split_pass, we split the module by forward FLOPs.
"""
mod_graph = gm.graph
# To use avgcompute_split_pass, we need to run the MetaInfoProp interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0
for node in mod_graph.nodes:
total_fwd_flop += node.fwd_flop
partition_flop = total_fwd_flop // pp_size
accumulate_fwd_flop = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop:
total_fwd_flop = total_fwd_flop - accumulate_fwd_flop
accumulate_fwd_flop = 0
pp_size -= 1
partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In avgnode_split_pass, simply split the graph by node number.
@ -104,8 +138,10 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
total_element_size = total_element_size - accumulate_node_size
accumulate_node_size = 0
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
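
Both hunks above add the same accumulate-and-rebalance idea: once a partition fills up, the remaining budget is recomputed from what is left. The pure-Python sketch below walks through it on a made-up per-node FLOP list, mirroring avgcompute_split_pass; no torch.fx graph is involved.

# Pure-Python sketch of the accumulate-and-rebalance split (made-up FLOP numbers).
fwd_flops = [10, 30, 20, 5, 35, 25, 15, 40]           # hypothetical per-node forward FLOPs
pp_size = 3

total = sum(fwd_flops)                                # 180
partition = total // pp_size                          # 60
acc, splits = 0, []
for idx, flop in enumerate(fwd_flops):
    if pp_size <= 1:
        break
    acc += flop
    if acc >= partition:
        splits.append(idx)                            # a pipe_split node would be inserted after this node
        total -= acc                                  # rebalance the remaining work
        acc = 0
        pp_size -= 1
        partition = total // pp_size

print(splits)                                         # -> [2, 5]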

View File

@ -112,7 +112,8 @@ class MetaInfoProp(torch.fx.Interpreter):
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
n.meta['type'] = type(result)
# retain the autograd graph

View File

@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
aten.sub.Tensor,
aten.sub_.Tensor,
# activation op
aten.hardswish.default,
@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.where.self,
aten.zero_.default,
aten.zeros_like.default,
]
aten.fill_.Scalar
] # yapf: disable
for op in zero_flop_aten:
flop_mapping[op] = zero_flop_jit

View File

@ -1,6 +1,4 @@
import uuid
from copy import deepcopy
from typing import Optional
import torch
from torch.types import _bool, _device, _dtype
@ -28,8 +26,6 @@ class MetaTensor(torch.Tensor):
_tensor: torch.Tensor
__slots__ = ['_tensor']
@staticmethod
def __new__(cls, elem, fake_device=None):
# Avoid multiple wrapping
@ -47,7 +43,7 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=fake_device if fake_device is not None else elem.device,
device=fake_device if fake_device is not None else torch.device('cpu'),
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
@ -59,8 +55,8 @@ class MetaTensor(torch.Tensor):
def __repr__(self):
if self.grad_fn:
return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@ -76,13 +72,13 @@ class MetaTensor(torch.Tensor):
x = x.to(torch.device('meta'))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
@ -118,23 +114,24 @@ class MetaTensor(torch.Tensor):
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
"""
# this imitates c++ function in the way of @overload
device = None
for arg in args:
if isinstance(arg, str) or isinstance(arg, _device):
device = arg
if 'device' in kwargs:
device = kwargs['device']
result = super().to(*args, **kwargs)
if device is not None:
result = MetaTensor(result, fake_device=device)
return result
fake_device = None
def replace(x):
nonlocal fake_device
if isinstance(x, str) or isinstance(x, _device):
fake_device = x
return 'meta'
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, fake_device=fake_device)
def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)
def cuda(self, *args, **kwargs):
if self.device.type == 'cuda':
return self.to(*args, **kwargs)
return self.to(*args, device='cuda', **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
return self.to(device='cuda:0', non_blocking=non_blocking)
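
Note: a small illustration of the reworked ``to``/``cuda`` behaviour, assuming the ``MetaTensor`` patched above; the payload always stays on the meta device, while the wrapper reports the requested fake device.

import torch
from colossalai.fx.profiler import MetaTensor

x = MetaTensor(torch.empty(4, 4, device='meta'), fake_device='cpu')
y = x.cuda()              # forwarded to .to(device='cuda:0', non_blocking=False)
print(y.device)           # cuda:0 -- only the reported (fake) device changes
print(y._tensor.device)   # meta   -- no real allocation happens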

View File

@ -13,6 +13,7 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt=False,
) -> ColoGraphModule:
"""
Symbolic tracing API
@ -49,6 +50,6 @@ def symbolic_trace(
This API is still under development and may contain bugs. Feel free to report any bugs to the Colossal-AI team.
"""
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
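
Note: a usage sketch of the new ``trace_act_ckpt`` switch; the model and meta arguments are placeholders and the import path is assumed to re-export this API.

import torch
from colossalai.fx import symbolic_trace   # assumed re-export of the function above

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
gm = symbolic_trace(model,
                    meta_args={'input': torch.empty(2, 16, device='meta')},
                    trace_act_ckpt=True)    # also record activation-checkpoint regions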

View File

@ -1,7 +1,7 @@
import enum
import functools
import operator
import inspect
import operator
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@ -286,7 +286,6 @@ class ColoTracer(Tracer):
self.graph.lint()
return self.graph
@contextmanager
def trace_activation_checkpoint(self, enabled: bool):
if enabled:
@ -316,7 +315,6 @@ class ColoTracer(Tracer):
# recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
@ -385,18 +383,23 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt=False,
) -> ColoGraphModule:
if is_compatible_with_meta():
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
concrete_args=concrete_args,
meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
node.kwargs)
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
if kind == 'placeholder':
meta_out = meta_args[target] if target in meta_args else concrete_args.get(
_truncate_suffix(target), None)
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
elif kind == 'get_attr':
attr_itr = root
atoms = target.split(".")
@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
else:
if target not in _TensorPropertyMethod:
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
**tree_map(unwrap_fn, kwargs))
elif kind == 'call_module':
mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
meta_out = None
return meta_out
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
meta_out = meta_args[target]
@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
return meta_out
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
result_graph = Graph()
value_remap = {}
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
else:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
elif kind == "call_module":
# if not hasattr(self, "orig_forward"):
@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
if handle is not None:
handle.generate()
for node_inserted in tracer.graph.nodes:
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
last_node = value_remap[node_inserted]
value_remap[orig_node] = last_node
else:
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
del tracer
gm.graph = result_graph
gm.recompile()
meta_prop_pass(gm, root_model, meta_args)

View File

@ -6,17 +6,14 @@ import torch.nn as nn
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter
def in_ddp(param: nn.Parameter) -> bool:
return not getattr(param, '_ddp_to_ignore', False)
from colossalai.utils import is_ddp_ignored
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""
Filter out parameters whose size is too large (more than 3 standard deviations away from the others).
"""
params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
params_size_arr = np.array(params_size)
std = np.std(params_size_arr)
@ -56,7 +53,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in param_order.generate():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if not in_ddp(param):
if is_ddp_ignored(param):
continue
param_key = param.process_group.dp_world_size()
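
Note: one plausible reading of the 3-standard-deviation rule described in the docstring above, written as a standalone helper over plain parameter sizes (the real implementation filters ``size_dict`` in place and may differ in details):

import numpy as np

def filter_exlarge(param_sizes):
    sizes = np.array(param_sizes)
    # drop parameters lying more than 3 standard deviations above the mean size
    upper_bound = sizes.mean() + 3 * sizes.std()
    return [s for s in param_sizes if s <= upper_bound]

# filter_exlarge([1024] * 10 + [10 ** 9]) keeps only the ten small parameters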

View File

@ -6,8 +6,14 @@ import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.chunk.search_utils import search_chunk_configuration
from colossalai.utils import is_ddp_ignored
def safe_div(a, b):
if a == 0:
return 0
return a / b
def init_chunk_manager(model: nn.Module,
@ -16,7 +22,6 @@ def init_chunk_manager(model: nn.Module,
search_range_mb: Optional[float] = None,
min_chunk_size_mb: Optional[float] = None,
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
kwargs_dict = dict()
if hidden_dim:
@ -34,7 +39,7 @@ def init_chunk_manager(model: nn.Module,
if filter_exlarge_params:
kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
total_size = sum(params_sizes) / 1024**2
dist.barrier()
@ -50,7 +55,7 @@ def init_chunk_manager(model: nn.Module,
if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
sep='',
flush=True)
dist.barrier()

View File

@ -50,6 +50,17 @@ class GeminiManager:
self._warmup = True
self._comp_cuda_demand_time = 0
def reset_attributes(self):
self._compute_idx = -1
self._h2d_volume = 0
self._d2h_volume = 0
self._layout_time = 0
self._evict_time = 0
self._comp_cuda_demand_time = 0
def is_warmup(self):
return self._warmup
def memstats(self):
"""memstats
@ -73,12 +84,7 @@ class GeminiManager:
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.finish_collection()
self._warmup = False
self._compute_idx = -1
self._h2d_volume = 0
self._d2h_volume = 0
self._layout_time = 0
self._evict_time = 0
self._comp_cuda_demand_time = 0
self.reset_attributes()
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of stateful tensors according to the information provided

View File

@ -15,26 +15,25 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from colossalai.core import global_context as gpc
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.logging import get_dist_logger
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from colossalai.engine import Engine
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
from colossalai.utils.moe import sync_moe_model_param
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.engine.gradient_accumulation import accumulate_gradient
from colossalai.engine.schedule import (
InterleavedPipelineSchedule,
NonPipelineSchedule,
PipelineSchedule,
get_tensor_shape,
)
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
from colossalai.utils.moe import sync_moe_model_param
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
@ -301,9 +300,9 @@ def initialize(model: nn.Module,
model = model().to(get_current_device())
# optimizer may be an optimizer class rather than an instance
logger.warning("Initializing a non-ZeRO model with an optimizer class")
if isinstance(optimizer, Callable):
optimizer = optimizer(model.parameters())
logger.warning("Initializing a non-ZeRO model with an optimizer class")
if not use_zero:
if is_using_sequence():

View File

@ -114,6 +114,13 @@ class FusedScaleMaskSoftmax(nn.Module):
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
try:
from colossalai._C import scaled_masked_softmax
except ImportError:
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
self.scaled_masked_softmax = scaled_masked_softmax
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
@ -178,11 +185,5 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
try:
import colossalai._C.scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
def get_batch_per_block(self, sq, sk, b, np):
return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)

View File

@ -19,7 +19,7 @@ class CPUAdam(NVMeOptimizer):
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``.
`CPUAdam` requires CUDA extensions which can be built during installation or runtime.
This version of CPU Adam accelerates parameter updates on CPU with SIMD.
Support of AVX2 or AVX512 is required.

View File

@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Currently GPU-only. Requires ColossalAI to be installed via
``pip install .``.
`FusedAdam` requires CUDA extensions which can be built during installation or runtime.
This version of fused Adam implements 2 fusions.

View File

@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
Currently GPU-only. Requires ColossalAI to be installed via
``pip install .``.
`FusedLAMB` requires CUDA extensions which can be built during installation or runtime.
This version of fused LAMB implements 2 fusions.

View File

@ -10,8 +10,7 @@ from colossalai.utils import multi_tensor_applier
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
Currently GPU-only. Requires ColossalAI to be installed via
``pip install .``.
`FusedSGD` requires CUDA extensions which can be built during installation or runtime.
This version of fused SGD implements 2 fusions.

View File

@ -19,7 +19,7 @@ class HybridAdam(NVMeOptimizer):
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``
`HybridAdam` requires CUDA extensions which can be built during installation or runtime.
This version of Hybrid Adam is a hybrid of CPUAdam and FusedAdam.

View File

@ -1,4 +1,5 @@
import math
import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple
@ -12,7 +13,7 @@ from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@ -78,8 +79,16 @@ class ZeroOptimizer(ColossalaiOptimizer):
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, module.fp32_params):
ddp_param_list = []
for name, param in module.named_parameters():
if is_ddp_ignored(param):
if param.requires_grad:
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!")
else:
ddp_param_list.append(param)
for p, fp32_p in zip(ddp_param_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
@ -140,6 +149,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self._found_overflow.item() > 0
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
c16.l2_norm = None
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
@ -201,6 +214,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
@ -285,6 +299,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
fake_params_list = list()
for param in group['params']:
if is_ddp_ignored(param):
continue
chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]:

View File

@ -12,12 +12,14 @@ from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from .reducer import Reducer
from .utils import get_static_torch_model
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@ -80,7 +82,7 @@ class ColoDDP(torch.nn.Module):
self.reducer = Reducer(bucket_cap_mb)
self.rebuild_bucket = rebuild_bucket
for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False):
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
@ -115,7 +117,7 @@ class ColoDDP(torch.nn.Module):
if self.rebuild_bucket:
self.reducer.free()
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
if is_ddp_ignored(p):
continue
if p.grad.device.type != "cpu":
p.grad = p._saved_grad
@ -199,14 +201,18 @@ class ZeroDDP(ColoDDP):
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
Defaults to False.
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False) -> None:
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
@ -231,8 +237,11 @@ class ZeroDDP(ColoDDP):
for p in param_order.generate():
assert isinstance(p, ColoParameter)
if getattr(p, '_ddp_to_ignore', False):
p.data = p.data.half()
if strict_ddp_mode and not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue
fp32_data = p.data.float()
@ -251,10 +260,11 @@ class ZeroDDP(ColoDDP):
pin_memory=pin_memory)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
self._cast_buffers()
params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
for p, fp32_p in zip(params_list, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
@ -266,19 +276,42 @@ class ZeroDDP(ColoDDP):
self._logger = get_dist_logger()
def _post_forward(self):
"""This function is only triggered for inference.
"""
access_list = list(self.chunk_manager.accessed_chunks)
# we need to scatter all accessed chunks and move them to their original places
for chunk in access_list:
assert chunk.can_release
self.chunk_manager.release_chunk(chunk)
first_param = next(iter(chunk.tensors_info))
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs):
# check whether we are in inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.is_warmup(), "You should run a complete iteration as warmup before inference"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
# scatter chunks in the inference mode
if not grad_flag:
self._post_forward()
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
def _setup_grads_ptr(self):
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
if is_ddp_ignored(p):
continue
p.grad = None
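
Note: a sketch of the intended flow for the inference path added above, assuming a standard Gemini setup (``model`` is a ``ZeroDDP``/``GeminiDDP`` instance, ``optimizer`` a ``ZeroOptimizer``; all other names are placeholders). Run at least one full training iteration first so the warmup phase is over, then call forward under ``torch.no_grad()`` so ``_post_forward`` scatters the accessed chunks.

# warmup: one complete training iteration
logits = model(data)
loss = criterion(logits, labels)
optimizer.backward(loss)
optimizer.step()
optimizer.zero_grad()

# inference: grad is disabled, so chunks are released and moved back after forward
with torch.no_grad():
    preds = model(eval_data)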
@ -331,12 +364,10 @@ class ZeroDDP(ColoDDP):
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
r"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
"""
Args:
strict (bool): whether to return the whole model state as the PyTorch ``Module.state_dict()`` does
Returns:
dict:
@ -346,7 +377,30 @@ class ZeroDDP(ColoDDP):
>>> module.state_dict().keys()
['bias', 'weight']
"""
if strict:
assert keep_vars is False, "`state_dict` with `keep_vars=True` is not supported yet."
torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
return self._non_strict_state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars,
only_rank_0=only_rank_0)
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non-strict state dict ignores a parameter if its tensor is shared with another
parameter that has already been included in the dictionary.
When loading such a state dict, you should pass `strict=False` to `load_state_dict`.
Returns:
dict:
a dictionary containing a whole state of the module
"""
if destination is None:
destination = OrderedDict()
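
Note: a brief usage sketch of the new ``strict`` switch (``zero_model`` is a placeholder for a ``ZeroDDP`` instance, ``torch_model`` for the matching plain module):

full_sd = zero_model.state_dict()                # strict=True: keys match Module.state_dict()
loose_sd = zero_model.state_dict(strict=False)   # may skip params sharing tensors with saved ones
torch_model.load_state_dict(loose_sd, strict=False)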
@ -405,8 +459,14 @@ class ZeroDDP(ColoDDP):
assert keep_vars is False, "`state_dict` with `keep_vars=True` is not supported yet."
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# TODO: (HELSON) deal with ddp ignored parameters
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
ddp_param_list = []
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
ddp_param_list.append((name, param))
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
if p is not None:
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
@ -542,8 +602,16 @@ class ZeroDDP(ColoDDP):
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
ddp_param_list = []
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
load(name, param, param.copy_)
else:
ddp_param_list.append((name, param))
fp32_to_name = dict()
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
if p is not None:
fp32_to_name[fp32_p] = name

View File

@ -17,6 +17,7 @@ class GeminiDDP(ZeroDDP):
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: Optional[float] = None,
@ -54,4 +55,4 @@ class GeminiDDP(ZeroDDP):
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
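
Note: a hypothetical construction showing the new flag being forwarded to ``ZeroDDP``; only the keywords visible in this hunk are spelled out, and the remaining arguments (e.g. ``device``) are assumed to keep their existing positions and defaults.

from colossalai.utils import get_current_device

model = GeminiDDP(module,
                  device=get_current_device(),    # assumed unchanged parameter
                  placement_policy='cpu',
                  pin_memory=True,
                  strict_ddp_mode=True)           # replicate every tensor, no sharding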

View File

@ -47,30 +47,29 @@ def _get_shallow_copy_model(model: nn.Module):
"""Get a shallow copy of the given model. Each submodule is different from the original submodule.
But the new submodule and the old submodule share all attributes.
"""
name_to_module = dict()
old_to_new = dict()
for name, module in _get_dfs_module_list(model):
new_module = copy(module)
new_module._modules = OrderedDict()
for subname, submodule in module._modules.items():
if submodule is None:
continue
full_name = name + ('.' if name else '') + subname
setattr(new_module, subname, name_to_module[full_name])
name_to_module[name] = new_module
return name_to_module['']
setattr(new_module, subname, old_to_new[submodule])
old_to_new[module] = new_module
return old_to_new[model]
def get_static_torch_model(gemini_ddp_model,
def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given GeminiDDP module.
You should notice that the original GeminiDDP model is not modified.
"""Get a static torch.nn.Module model from the given ZeroDDP module.
You should notice that the original ZeroDDP model is not modified.
Thus, you can use the original model in further training.
But you should not use the returned torch model for training, as this can cause unexpected errors.
Args:
gemini_ddp_model (GeminiDDP): a gemini ddp model
zero_ddp_model (ZeroDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank 0 has the converted torch model
@ -78,11 +77,11 @@ def get_static_torch_model(gemini_ddp_model,
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.nn.parallel import GeminiDDP
assert isinstance(gemini_ddp_model, GeminiDDP)
from colossalai.nn.parallel import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP)
state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = gemini_ddp_model.module
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
colo_model = zero_ddp_model.module
torch_model = _get_shallow_copy_model(colo_model)
if not only_rank_0 or dist.get_rank() == 0:
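
Note: a checkpoint-saving sketch using the renamed helper (``zero_model`` is a placeholder for any ``ZeroDDP``/``GeminiDDP`` model; the import path follows this file, and a distributed process group is assumed to be initialised):

import torch
import torch.distributed as dist
from colossalai.nn.parallel.utils import get_static_torch_model

torch_model = get_static_torch_model(zero_model, only_rank_0=True)
if dist.get_rank() == 0:
    torch.save(torch_model.state_dict(), 'checkpoint.pt')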

View File

@ -211,7 +211,7 @@ class WorkerBase(ABC):
refcount = 0
with self.output_list_condition_lock:
if refcount < lifecycle:
if refcount <= lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
@ -390,7 +390,7 @@ class WorkerBase(ABC):
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank)
producer_output_key, rank=self.pp_rank, offsets=offsets)
else:
for i in range(producer_num):

View File

@ -29,9 +29,6 @@ class FillDrainWorker(WorkerBase):
target_key = UniqueKey(target_microbatch_id, target_phase)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key

View File

@ -1,22 +1,46 @@
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
sync_model_param, disposable)
from .common import (
clip_grad_norm_fp32,
conditional_context,
copy_tensor_parallel_attributes,
count_zeros_fp32,
disposable,
ensure_path_exists,
free_port,
is_ddp_ignored,
is_dp_rank_0,
is_model_parallel_parameter,
is_no_pp_or_last_stage,
is_tp_rank_0,
is_using_ddp,
is_using_pp,
is_using_sequence,
multi_tensor_applier,
param_is_not_tensor_parallel_duplicate,
print_rank_0,
switch_virtual_pipeline_parallel_rank,
sync_model_param,
)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction,
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
from .timer import MultiTimer, Timer
from .memory import (
colo_device_memory_capacity,
colo_device_memory_used,
colo_get_cpu_memory_capacity,
colo_set_cpu_memory_capacity,
colo_set_process_memory_fraction,
report_memory_usage,
)
from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer
__all__ = [
'checkpoint',
'free_port',
'print_rank_0',
'sync_model_param',
'is_ddp_ignored',
'is_dp_rank_0',
'is_tp_rank_0',
'is_no_pp_or_last_stage',

View File

@ -126,14 +126,18 @@ def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
def is_ddp_ignored(p):
return getattr(p, '_ddp_to_ignore', False)
def _calc_l2_norm(grads):
# we should not
# we should not
global fused_optim
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
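
Note: how the new helper is expected to be used, as a sketch: a parameter carries the ``_ddp_to_ignore`` flag and ``is_ddp_ignored`` simply reads it back.

import torch
from colossalai.utils import is_ddp_ignored

param = torch.nn.Parameter(torch.zeros(8))
assert not is_ddp_ignored(param)        # flag absent -> False
setattr(param, '_ddp_to_ignore', True)  # the attribute checked by the helper
assert is_ddp_ignored(param)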

View File

@ -0,0 +1,440 @@
import contextlib
import copy
import gc
import pprint
from typing import Callable, List, Optional, Union
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.profiler import MetaTensor
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_TorchFactoryMethod = [
"arange",
"empty",
"eye",
"full",
"linspace",
"logspace",
"ones",
"rand",
"randn",
"randint",
"randperm",
"zeros",
"tensor",
]
orig_empty = torch.empty # avoid override
scm = ShapeConsistencyManager()
class LazyTensor(torch.Tensor):
"""A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).
Usage:
1. Use ``LazyTensor`` instead of ``torch.Tensor``.
>>> x = LazyTensor(torch.zeros, 2, 3)
>>> x += 1
>>> y = x * x
>>> y = y.cuda().half()
>>> y[0, 0] = 0
>>> y = y.materialize() # materialize the tensor
>>> print(y)
tensor([[0., 1., 1.],
[1., 1., 1.]], device='cuda:0', dtype=torch.float16)
2. Generate ``MetaTensor`` from ``LazyTensor``
>>> x = LazyTensor(torch.zeros, 2, 3)
>>> x.reshape(3, 2)
>>> x = x.traceable() # generate ``MetaTensor``
>>> print(x)
MetaTensor(..., size=(3, 2), device='cpu', dtype=torch.float32)
3. Use ``LazyTensor`` to generate sharded ``nn.Parameter``.
>>> x = LazyTensor(torch.zeros, 2, 3)
>>> x.spec = ... # some ``ShardingSpec``
>>> x.distribute() # distribute the tensor according to the ``ShardingSpec``
Warnings:
1. Cases that ``LazyTensor`` can't deal with.
>>> x = LazyTensor(torch.ones, 2, 3)
>>> x[0, 0] = -x[0, 0] # this will cause infinite recursion
2. ``LazyTensor.materialize()`` can't be called multiple times.
>>> x = LazyTensor(torch.ones, 2, 3)
>>> x.materialize()
>>> x.materialize() # this is disallowed
"""
_repr = True
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
_cached_data: Optional[torch.Tensor] = None # materialized data
@staticmethod
def __new__(cls, func, *args, dtype=None, device=None, **kwargs):
elem = func(*args, dtype=dtype, device='meta', **kwargs)
r = torch.Tensor._make_wrapper_subclass(cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=device if device is not None else torch.device('cpu'),
requires_grad=elem.requires_grad)
r._meta_data = MetaTensor(elem, fake_device=device)
return r
def __init__(self, func, *args, dtype=None, device=None, **kwargs):
self._factory_method = (func, args, {'dtype': dtype, 'device': device, **kwargs}) # (func, args, kwargs)
self._cached_buffer = list() # (func, args, kwargs)
self._spec = None
self._data = self
def __repr__(self):
if self._repr:
# avoid recursive representation
self.__class__._repr = False
s = f'LazyTensor(..., size={tuple(self._meta_data.shape)}, device={self._meta_data.device}, dtype={self._meta_data.dtype})\n'\
f'factory method: {self._factory_method}\n'\
f'cached: {pprint.pformat(self._cached_buffer) if self._cached_data is None else self._cached_data}\n'\
f'spec: {self._spec}'
self.__class__._repr = True
return s
else:
return 'LazyTensor(...)'
def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
Warnings:
Calling ``self.materialize()`` clears the cached sequence and the factory method,
because materializing the same ``LazyTensor`` twice is not allowed.
This is mentioned in the paper: https://arxiv.org/pdf/2102.13267.pdf (Section 4.3).
Returns:
torch.Tensor: The materialized tensor.
"""
target = self._data._realize_cached_data()
if isinstance(self, nn.Parameter):
target = nn.Parameter(target, requires_grad=self.requires_grad)
self._clear_all()
return target
def traceable(self) -> MetaTensor:
"""Generate ``MetaTensor`` from ``LazyTensor``. (Mostly for tracing)
Returns:
MetaTensor: The generated ``MetaTensor``.
"""
if isinstance(self, nn.Parameter):
return nn.Parameter(self._meta_data, requires_grad=self.requires_grad)
else:
return self._meta_data
def distribute(self) -> torch.Tensor:
"""Distribute the ``LazyTensor`` according to the ``ShardingSpec``.
Returns:
torch.Tensor: The sharded tensor.
"""
if self._spec is None:
raise RuntimeError(f'ShardingSpec is not set for\n{self}')
spec, device_mesh = self._spec, self._spec.device_mesh
target = self.materialize()
# TODO(some man): better not be coupled with auto-parallel
target.data = scm.apply_for_autoparallel_runtime(target.data, ShardingSpec(device_mesh, target.shape, {}),
spec).detach().clone()
return target
def _realize_cached_data(self) -> torch.Tensor:
# self._cached_data should be generated after the first call of this function
if self._cached_data is None:
if self._factory_method is not None:
# apply factory method
func, args, kwargs = self._factory_method
# apply cached sequence
self._cached_data = self._apply_cache_buffer(func(*args, **kwargs))
else:
# apply cached sequence only
self._cached_data = self._apply_cache_buffer()
return self._cached_data
def _apply_cache_buffer(self, target=None) -> torch.Tensor:
# dump all cached sequence
# super-dainiu: support methods for single Tensor only
def replace(x):
if x is self:
return target
elif isinstance(x, LazyTensor):
return x._realize_cached_data()
return x
packed = None
for (func, args, kwargs) in self._cached_buffer:
if func == torch.Tensor.requires_grad_:
packed = func, args, kwargs # requires grad should be set at last
else:
o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value
# super-dainiu: set requires_grad after all inplace-ops are done
if packed is not None:
func, args, kwargs = packed
func(*tree_map(replace, args), **tree_map(replace, kwargs))
return target
# clear all means:
# 1. clear factory method
# 2. clear cached sequence
# 3. clear cached data
def _clear_all(self):
self._cached_data = None
self._cached_buffer = None
self._data = None
gc.collect() # avoid memory leak
# cache everything with __torch_function__
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
target = None
if isinstance(func, torch._C.ScriptMethod):
def unwrap(x):
if isinstance(x, LazyTensor):
return x._meta_data
return x
target: LazyTensor = args[0].clone()
target._cached_buffer.append((func, args, kwargs))
target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
**tree_map(unwrap, kwargs))
else:
def unwrap(x):
nonlocal target
if isinstance(x, LazyTensor):
target = x if (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
or func.__name__ == "__setitem__") else x.clone()
target._cached_buffer.append((func, args, kwargs))
return x._meta_data
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
o = func(*args, **kwargs)
if isinstance(o, MetaTensor):
target._meta_data = o
return target
else:
return o
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass # skip
def clone(self) -> "LazyTensor":
"""Create a new ``LazyTensor`` with same cached sequence and factory method.
Returns:
LazyTensor: the new ``LazyTensor``
"""
target = LazyTensor(orig_empty, 0, dtype=self._meta_data.dtype, device=self._meta_data.device)
target._factory_method = None
target._cached_buffer = list()
target._meta_data = self._meta_data.clone()
target._cached_data = self._cached_data.clone() if self._cached_data is not None else None
target._spec = copy.deepcopy(self._spec)
return target
def detach(self) -> "LazyTensor":
target = self.clone()
target._cached_buffer.append((torch.Tensor.detach_, (self,), {}))
return target
@property
def spec(self) -> ShardingSpec:
return self._spec
@spec.setter
def spec(self, other: ShardingSpec):
self._spec = other
@property
def data(self) -> "LazyTensor":
return self._data.detach()
@data.setter
def data(self, other: "LazyTensor") -> "LazyTensor":
"""This avoid the following infinite recursion, which is very common in ``nn.Module`` initialization.
Usage:
>>> a = LazyTensor(torch.empty, 0, dtype=torch.float32, device='cpu')
>>> b = a.cuda()
>>> a.data = b
"""
self._data = other
class LazyInitContext():
"""Context manager for lazy initialization. Enables initializing the model without allocating real memory.
Usage:
1. The model is initialized, but no real memory is allocated.
>>> ctx = LazyInitContext()
>>> with ctx:
>>> model = MyModel().cuda()
2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated.
>>> with ctx.traceable(model):
>>> gm = symbolic_trace(model, meta_args=meta_args)
>>> # Solve the execution strategy and apply the strategy to the model
>>> strategy = StrategyAndSpec()
3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device)
>>> model = ctx.materialize(model)
4. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario)
>>> model = apply_strategy_to_all_params(model, strategy)
>>> model = ctx.distribute(model)
Warnings:
This API is still experimental and further modifications can be made to it.
For example:
1. Quantization strategies can be applied before allocating real memory.
2. Lazy initialization seems slower than normal initialization.
"""
def __init__(self):
self.overrides = {}
def __enter__(self):
def wrap_factory_method(target):
# factory functions (eg. torch.empty())
def wrapper(*args, **kwargs):
return LazyTensor(target, *args, **kwargs)
return wrapper, target
def wrap_factory_like_method(orig_target, target):
# factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs):
orig_t = args[0]
return LazyTensor(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
return wrapper, target
self.overrides = {
target: wrap_factory_method(getattr(torch, target))
for target in _TorchFactoryMethod
if callable(getattr(torch, target, None))
}
self.overrides.update({
target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
for target in _TorchFactoryMethod
if callable(getattr(torch, target + '_like', None))
})
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, wrapper)
def __exit__(self, exc_type, exc_val, exc_tb):
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, orig)
@staticmethod
def materialize(module: torch.nn.Module):
"""Initialize all ``nn.Parameter`` from ``LazyTensor``.
Args:
module (torch.nn.Module): Target ``nn.Module``
"""
@torch.no_grad()
def init_recursively(module: nn.Module):
# recursively initialize the module
for mod in module.children():
init_recursively(mod)
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
setattr(module, name, param.materialize())
for name, buf in module.named_buffers(recurse=False):
setattr(module, name, buf.materialize())
init_recursively(module)
return module
@staticmethod
def distribute(module: torch.nn.Module):
"""Initialize and shard all ``nn.Parameter`` from ``LazyTensor``.
Args:
module (torch.nn.Module): Sharded target ``nn.Module``
"""
@torch.no_grad()
def init_recursively(module: nn.Module):
# recursively initialize the module
for mod in module.children():
init_recursively(mod)
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
setattr(module, name, param.distribute())
for name, buf in module.named_buffers(recurse=False):
setattr(module, name, buf.distribute())
init_recursively(module)
return module
@staticmethod
@contextlib.contextmanager
def traceable(module: torch.nn.Module):
"""Initialize all ``nn.Parameters`` as ``MetaTensor``. This enables ``ColoTracer`` with control flow.
Args:
module (torch.nn.Module): Traceable ``nn.Module`` with ``MetaTensor`` as parameters.
"""
orig_val = dict()
def init_recursively(module: nn.Module):
# recursively initialize the module
for mod in module.children():
init_recursively(mod)
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
setattr(module, name, param.traceable())
orig_val[(module, name)] = param
for name, buf in module.named_buffers(recurse=False):
setattr(module, name, buf.traceable())
orig_val[(module, name)] = buf
init_recursively(module)
yield
# restore original values
for (module, name), val in orig_val.items():
setattr(module, name, val)

View File

@ -1,4 +1,5 @@
import math
from typing import Optional
import torch
import torch.distributed as dist
@ -7,6 +8,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ProcessGroup
from colossalai.utils import is_model_parallel_parameter
@ -101,7 +103,11 @@ def split_half_float_double(tensor_list):
return buckets
def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
def reduce_tensor_dp_group(tensor: torch.Tensor,
dtype: Optional[torch.dtype] = None,
dst_local_rank: Optional[int] = None,
dst_global_rank: Optional[int] = None,
group: Optional[dist.ProcessGroup] = None):
"""
Reduce the tensor in the data parallel process group
@ -114,7 +120,7 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
:type tensor: torch.Tensor
:type dtype: torch.dtype, optional
:type dst_rank: int, optional
:type parallel_mode: ParallelMode, optional
:type pg: ProcessGroup, optional
"""
# use the original dtype
if dtype is None:
@ -126,25 +132,22 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
else:
tensor_to_reduce = tensor
world_size = gpc.get_world_size(parallel_mode)
group = gpc.get_group(parallel_mode)
world_size = dist.get_world_size(group=group)
tensor_to_reduce.div_(world_size)
# if rank is None, all reduce will be used
# else, reduce is used
use_all_reduce = dst_rank is None
use_all_reduce = dst_local_rank is None
if use_all_reduce:
dist.all_reduce(tensor_to_reduce, group=group)
else:
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
global_rank = ranks_in_group[dst_rank]
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)
# recover the original dtype
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
local_rank = gpc.get_local_rank(parallel_mode)
if use_all_reduce or dst_rank == local_rank:
local_rank = dist.get_rank(group=group)
if use_all_reduce or dst_local_rank == local_rank:
tensor.copy_(tensor_to_reduce)
return tensor
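
Note: a calling sketch for the refactored helper, assuming ``torch.distributed`` is already initialised and ``reduce_tensor_dp_group`` from the module above is in scope; ``dist.group.WORLD`` stands in for the data-parallel group, and ``dst_local_rank=None`` selects the all-reduce path.

import torch
import torch.distributed as dist

grad = torch.ones(1024, device='cuda')
reduce_tensor_dp_group(tensor=grad,
                       dtype=torch.float32,
                       dst_local_rank=None,    # None -> all_reduce over the group
                       dst_global_rank=None,
                       group=dist.group.WORLD)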

View File

@ -1,12 +1,12 @@
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
import torch.distributed as dist
from torch.distributed import ProcessGroup
class BaseStore:
def __init__(self, dp_parallel_mode=ParallelMode.DATA):
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
@property
def world_size(self):

View File

@ -1,14 +1,12 @@
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from torch.distributed import ProcessGroup
from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, dp_parallel_mode):
super().__init__(dp_parallel_mode)
self._grads = dict()
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self._params = dict()
self._num_elements_in_bucket = dict()
@ -20,25 +18,24 @@ class BucketStore(BaseStore):
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
def add_grad(self, tensor, reduce_rank: int = None):
self._grads[reduce_rank].append(tensor)
def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
def reset(self):
keys = [None] + list(range(self._world_size))
self._grads = {rank: [] for rank in keys}
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}
def reset_by_rank(self, reduce_rank=None):
self._grads[reduce_rank] = []
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0
def get_grad(self, reduce_rank: int = None):
return self._grads[reduce_rank]
param_list = self.get_param(reduce_rank)
for param in param_list:
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
return [param.grad for param in param_list]
def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]

View File

@ -1,14 +1,15 @@
from typing import List
from torch import Tensor
from torch.distributed import ProcessGroup
from .base_store import BaseStore
class ParameterStore(BaseStore):
def __init__(self, dp_paralle_mode):
super().__init__(dp_paralle_mode)
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()

View File

@ -1,5 +1,5 @@
from functools import partial
from itertools import groupby
from typing import Optional
import torch
import torch.distributed as dist
@ -10,6 +10,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
from ._utils import (
@ -18,7 +19,7 @@ from ._utils import (
flatten,
get_grad_accumulate_object,
has_inf_or_nan,
reduce_tensor,
reduce_tensor_dp_group,
release_param_grad,
split_half_float_double,
sync_param,
@ -33,35 +34,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def __init__(
self,
optimizer: Optimizer,
# grad scaler config
initial_scale=2**16,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=2000,
hysteresis=2,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
# grad clipping
clip_grad_norm=0.0,
verbose=False,
# communication
reduce_bucket_size=1024 * 1024,
communication_dtype=None,
overlap_communication=False,
# stage 2
partition_grad=False,
dp_parallel_mode=ParallelMode.DATA,
mp_parallel_mode=ParallelMode.MODEL,
# cpu offload
cpu_offload=False,
# forced dtype
forced_dtype=None):
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None):
# TODO: add support for
# 1. fp16 master weights
@ -76,21 +63,32 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# stage 2
self._partition_grads = partition_grad
# cpu_offload
self._cpu_offload = cpu_offload
# get process groups
self._dp_parallel_mode = dp_parallel_mode
self._mp_parallel_mode = mp_parallel_mode
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
self._world_size = gpc.get_world_size(dp_parallel_mode)
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
self._local_rank = colo_pg.dp_local_rank()
self._world_size = colo_pg.dp_world_size()
self._dp_global_ranks = colo_pg.get_ranks_in_dp()
self._dp_torch_group = colo_pg.dp_process_group()
self._mp_torch_group = None
if colo_pg.tp_world_size() > 1:
self._mp_torch_group = colo_pg.tp_process_group()
elif colo_pg is None:
dp_parallel_mode = ParallelMode.DATA
mp_parallel_mode = ParallelMode.MODEL
self._dp_group = gpc.get_group(dp_parallel_mode)
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
self._mp_group = gpc.get_group(mp_parallel_mode)
self._dp_parallel_mode = dp_parallel_mode
self._mp_parallel_mode = mp_parallel_mode
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
self._dp_torch_group = gpc.get_group(dp_parallel_mode)
self._mp_torch_group = None
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
self._mp_torch_group = gpc.get_group(mp_parallel_mode)
else:
self._mp_group = None
raise NotImplementedError
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
self._fp32_flat_param_groups_of_current_rank = dict()
@ -126,9 +124,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self._dp_parallel_mode)
self._grad_store = GradientStore(self._dp_parallel_mode)
self._bucket_store = BucketStore(self._dp_parallel_mode)
self._param_store = ParameterStore(self._dp_torch_group)
self._grad_store = GradientStore(self._dp_torch_group)
self._bucket_store = BucketStore(self._dp_torch_group)
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@ -209,6 +207,30 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def num_param_groups(self):
return len(self._fp16_param_groups)
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
def _search_colo_process_group(self):
colo_flag = False
colo_pg = None
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
if isinstance(param, ColoParameter):
colo_flag = True
if colo_pg is None:
colo_pg = param.get_process_group()
else:
assert colo_pg == param.get_process_group(), "All parameters should be in the same process group"
elif colo_flag:
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
return colo_pg
def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
@ -223,22 +245,16 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
numel_per_rank[rank_to_go] += param.numel()
if self._verbose:
self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
ranks=[0],
parallel_mode=self._dp_parallel_mode)
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
###########################
# Backward Reduction Hook #
###########################
###########################################################
# Backward Reduction Hook
###########################################################
def _grad_handler(self, param, grad, reduce_rank):
self._add_to_reduction_bucket(param, reduce_rank)
return grad
def _attach_reduction_hook(self):
# we iterate over the fp16 params
@ -256,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
else:
reduce_rank = None
def _define_and_attach(param, reduce_rank):
# get the AccumulateGrad object of the param itself
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
param=param,
reduce_rank=reduce_rank)
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreacable and accumulate_grad
def reduce_grad_hook(*args):
reduction_func()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)
accum_grad_obj.register_hook(reduce_grad_hook)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
_define_and_attach(param, reduce_rank)
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
param_size = param.numel()
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_in_bucket(reduce_rank)
if param_bucket.is_full_or_oversized():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
if not param_bucket.is_empty():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
def _reduce_grads(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_grad(param.grad, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
for tensor_list in grad_buckets_by_dtype:
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
bucket_size=bucket_size,
reduce_rank=reduce_rank)
def _reduce_grads_in_bucket(self, reduce_rank=None):
#######################
# Reduction Functions #
#######################
def _run_reduction(self, reduce_rank=None):
# reduce grads
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
self._reduce_grads(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
# use communication stream if overlapping
# communication with computation
@ -339,46 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._bucket_store.reset_by_rank(reduce_rank)
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
def _add_to_reduction_bucket(self, param, reduce_rank=None):
param_size = param.numel()
for tensor_list in grad_buckets_by_dtype:
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._run_reduction(reduce_rank)
##############################
# Reduction Utility Function #
##############################
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(tensor=flat,
dtype=self._communication_dtype,
dst_rank=reduce_rank,
parallel_mode=self._dp_parallel_mode)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
################################
# torch.optim.Optimizer methods
@ -443,8 +445,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
dp_group=self._dp_group,
mp_group=self._mp_group)
dp_group=self._dp_torch_group,
mp_group=self._mp_torch_group)
norm_groups.append(norm_group)
# create flat gradient for the flat fp32 params
@ -482,9 +484,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for rank in range(self._world_size):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
for index in range(self._world_size):
rank = self._dp_global_ranks[index]
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)
for handle in handles:
@ -506,11 +509,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
break
# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
# all-reduce over model parallel group
if self._mp_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
if self._mp_torch_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
if self._found_overflow.item() > 0:
return True
@ -569,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._reduce_and_remove_grads_by_bucket(param)
self._add_to_reduction_bucket(param)
# we need to reduce the gradients
# left in the communication bucket
self._reduce_grads_in_bucket()
self._run_reduction()
def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
@ -581,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._reduce_grads_in_bucket(reduce_rank)
self._run_reduction(reduce_rank)

View File

@ -8,6 +8,7 @@ import torch
from colossalai.gemini import TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
class TrainingPhase(Enum):
@ -24,7 +25,7 @@ class GeminiZeROHook(ColoParamOpHook):
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
params = [p for p in params if not is_ddp_ignored(p)]
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@ -37,7 +38,7 @@ class GeminiZeROHook(ColoParamOpHook):
self._gemini_manager.record_model_data_volume()
def post_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
params = [p for p in params if not is_ddp_ignored(p)]
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
self._chunk_manager.trans_tensor_state(p, tensor_state)

View File

@ -1,17 +1,18 @@
FROM hpcaitech/cuda-conda:11.3
# install torch
RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
# install apex
RUN git clone https://github.com/NVIDIA/apex && \
cd apex && \
pip install packaging && \
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./
# install colossalai
RUN git clone https://github.com/hpcaitech/ColossalAI.git \
&& cd ./ColossalAI \
&& pip install -v --no-cache-dir .
&& CUDA_EXT=1 pip install -v --no-cache-dir .
# install titans
RUN pip install --no-cache-dir titans

View File

@ -1,28 +1,40 @@
## Examples folder document
# Colossal-AI Examples
## Table of Contents
<ul>
<li><a href="#Example-folder-description">Example folder description</a> </li>
<li><a href="#Integrate-Your-Example-With-System-Testing">Integrate Your Example With System Testing</a> </li>
</ul>
## Example folder description
- [Colossal-AI Examples](#colossal-ai-examples)
- [Table of Contents](#table-of-contents)
- [Overview](#overview)
- [Folder Structure](#folder-structure)
- [Integrate Your Example With Testing](#integrate-your-example-with-testing)
This folder provides several examples using colossalai. The images folder includes model like diffusion, dreambooth and vit. The language folder includes gpt, opt, palm and roberta. The tutorial folder is for concept illustration, such as auto-parallel, hybrid-parallel and so on.
## Overview
This folder provides several examples accelerated by Colossal-AI. The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI. Other folders such as `images` and `language` include a wide range of deep learning tasks and applications.
## Integrate Your Example With System Testing
## Folder Structure
For example code contributor, to meet the expectation and test your code automatically using github workflow function, here are several steps:
```text
└─ examples
└─ images
└─ vit
└─ test_ci.sh
└─ train.py
└─ README.md
└─ ...
└─ ...
```
## Integrate Your Example With Testing
- (must) Have a test_ci.sh file in the folder like shown below in 'File Structure Chart'
- The dataset should be located in the company's machine and can be announced using environment variable and thus no need for a separate terminal command.
- The model parameters should be small to allow fast testing.
- File Structure Chart
Regular checks are important to ensure that all examples run without apparent bugs and stay compatible with the latest API.
Colossal-AI therefore runs example workflows both on pull requests and on a weekly schedule.
When a new example is added or changed, the workflow runs it to verify that it still works, and every example is re-tested each week.
└─examples
└─images
└─vit
└─requirements.txt
└─test_ci.sh
Therefore, it is essential for example contributors to know how to integrate their examples with the testing workflow. Simply follow the steps below.
1. Create a script called `test_ci.sh` in your example folder
2. Configure your testing parameters, such as the number of steps and the batch size, in `test_ci.sh`. Keep these parameters small so that each example only takes several minutes.
3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via Slack to request that a dataset be downloaded onto the CI machine.
4. Implement the logic for dependency setup and example execution in the script; a minimal sketch is shown below.
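For illustration only, a minimal `test_ci.sh` along these lines might look like the sketch below. It mirrors the ViT example's `test_ci.sh` that appears later in this diff; the dataset directory and config file name are placeholders, not real paths.

```bash
#!/usr/bin/env bash
set -euxo pipefail

export OMP_NUM_THREADS=4

# install the example's own dependencies
pip install -r requirements.txt

# point to the shared dataset copy on the CI machine (my_dataset is a placeholder)
export DATA=/data/scratch/examples-data/my_dataset

# launch with a small, CI-sized config so the run finishes in a few minutes
# (configs/my_example_ci.py is a placeholder config)
colossalai run --nproc_per_node 4 train.py --config configs/my_example_ci.py
```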

View File

@ -26,6 +26,16 @@ Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1]
More details can be found in our [blog of Stable Diffusion v1](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) and [blog of Stable Diffusion v2](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0).
## Roadmap
This project is in rapid development.
- [X] Train a Stable Diffusion v1/v2 model from scratch
- [X] Finetune a pretrained Stable Diffusion v1 model
- [X] Run inference with a pretrained model using PyTorch
- [ ] Finetune a pretrained Stable Diffusion v2 model
- [ ] Run inference with a pretrained model using TensorRT
## Installation
### Option #1: install from source
@ -123,7 +133,7 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
### stable-diffusion-v1-5 from runway
If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) wiegh from runwayml
If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from runwayml:
```
git lfs install
@ -156,7 +166,7 @@ You can change the trainging config in the yaml file
- precision: the precision type used in training, default 16 (fp16), you must use fp16 if you want to apply colossalai
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)
## Finetune Example
## Finetune Example (Work In Progress)
### Training on Teyvat Datasets
We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is created with BLIP-generated captions.

View File

@ -153,7 +153,8 @@ def parse_args(input_args=None):
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
help=
"Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
)
parser.add_argument(
"--gradient_checkpointing",
@ -355,10 +356,14 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
def main(args):
colossalai.launch_from_torch(config={})
if args.seed is not None:
gpc.set_seed(args.seed)
if args.seed is None:
colossalai.launch_from_torch(config={})
else:
colossalai.launch_from_torch(config={}, seed=args.seed)
local_rank = gpc.get_local_rank(ParallelMode.DATA)
world_size = gpc.get_world_size(ParallelMode.DATA)
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
@ -387,7 +392,7 @@ def main(args):
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
disable=not local_rank == 0,
):
images = pipeline(example["prompt"]).images
@ -399,7 +404,7 @@ def main(args):
del pipeline
# Handle the repository creation
if gpc.get_local_rank(ParallelMode.DATA) == 0:
if local_rank == 0:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
@ -464,8 +469,9 @@ def main(args):
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size
unet = gemini_zero_dpp(unet, args.placement)
@ -554,7 +560,7 @@ def main(args):
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps
logger.info("***** Running training *****", ranks=[0])
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
@ -566,7 +572,7 @@ def main(args):
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
progress_bar.set_description("Steps")
global_step = 0
@ -643,7 +649,7 @@ def main(args):
if global_step % args.save_steps == 0:
torch.cuda.synchronize()
torch_unet = get_static_torch_model(unet)
if gpc.get_local_rank(ParallelMode.DATA) == 0:
if local_rank == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=torch_unet,
@ -658,7 +664,7 @@ def main(args):
torch.cuda.synchronize()
unet = get_static_torch_model(unet)
if gpc.get_local_rank(ParallelMode.DATA) == 0:
if local_rank == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,

View File

@ -0,0 +1,32 @@
from colossalai.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 8
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 3
WARMUP_EPOCHS = 1
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 32
DEPTH = 2
NUM_HEADS = 4
MLP_RATIO = 4
NUM_CLASSES = 10
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
USE_DDP = True
TP_WORLD_SIZE = 2
TP_TYPE = 'row'
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
gradient_accumulation = 2
LOG_PATH = "./log_ci"

View File

@ -1,2 +1,8 @@
colossalai >= 0.1.12
torch >= 1.8.1
numpy>=1.24.1
timm>=0.6.12
titans>=0.0.7
tqdm>=4.61.2
transformers>=4.25.1
nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist

View File

@ -0,0 +1,9 @@
export OMP_NUM_THREADS=4
pip install -r requirements.txt
# train
colossalai run \
--nproc_per_node 4 train.py \
--config configs/vit_1d_tp2_ci.py \
--dummy_data

View File

@ -7,6 +7,7 @@ import torch.nn.functional as F
from timm.models.vision_transformer import _create_vision_transformer
from titans.dataloader.imagenet import build_dali_imagenet
from tqdm import tqdm
from vit import DummyDataLoader
import colossalai
from colossalai.core import global_context as gpc
@ -56,8 +57,8 @@ def init_spec_func(model, tp_type):
def train_imagenet():
parser = colossalai.get_default_parser()
parser.add_argument('--from_torch', default=True, action='store_true')
parser.add_argument('--resume_from', default=False)
parser.add_argument('--resume_from', default=False, action='store_true')
parser.add_argument('--dummy_data', default=False, action='store_true')
args = parser.parse_args()
colossalai.launch_from_torch(config=args.config)
@ -74,10 +75,22 @@ def train_imagenet():
logger.log_to_file(log_path)
logger.info('Build data loader', ranks=[0])
root = os.environ['DATA']
train_dataloader, test_dataloader = build_dali_imagenet(root,
train_batch_size=gpc.config.BATCH_SIZE,
test_batch_size=gpc.config.BATCH_SIZE)
if not args.dummy_data:
root = os.environ['DATA']
train_dataloader, test_dataloader = build_dali_imagenet(root,
train_batch_size=gpc.config.BATCH_SIZE,
test_batch_size=gpc.config.BATCH_SIZE)
else:
train_dataloader = DummyDataLoader(length=10,
batch_size=gpc.config.BATCH_SIZE,
category=gpc.config.NUM_CLASSES,
image_size=gpc.config.IMG_SIZE,
return_dict=False)
test_dataloader = DummyDataLoader(length=5,
batch_size=gpc.config.BATCH_SIZE,
category=gpc.config.NUM_CLASSES,
image_size=gpc.config.IMG_SIZE,
return_dict=False)
logger.info('Build model', ranks=[0])

View File

@ -32,21 +32,24 @@ class DummyDataGenerator(ABC):
class DummyDataLoader(DummyDataGenerator):
batch_size = 4
channel = 3
category = 8
image_size = 224
def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
super().__init__(length)
self.batch_size = batch_size
self.channel = channel
self.category = category
self.image_size = image_size
self.return_dict = return_dict
def generate(self):
image_dict = {}
image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
DummyDataLoader.channel,
DummyDataLoader.image_size,
DummyDataLoader.image_size,
device=get_current_device()) * 2 - 1
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
image_dict['pixel_values'] = torch.rand(
self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
image_dict['label'] = torch.randint(self.category, (self.batch_size,),
dtype=torch.int64,
device=get_current_device())
if not self.return_dict:
return image_dict['pixel_values'], image_dict['label']
return image_dict

View File

@ -39,9 +39,15 @@ If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-
For simplicity, the input data is randomly generated here.
## Training
We provide two solutions. One utilizes the hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism.
The other one uses Pipeline Parallelism Only.
In the future, we are going merge them together and they can be used orthogonally to each other.
We provide two stable solutions.
One uses Gemini to implement a hybrid parallel strategy of Gemini, DDP/ZeRO, and Tensor Parallelism for a Hugging Face GPT model.
The other uses [Titans](https://github.com/hpcaitech/Titans), a distributed model zoo maintained by Colossal-AI, to implement a hybrid parallel strategy of TP + ZeRO + PP.
We recommend using Gemini to quickly run your model in a distributed manner.
It does not require significant changes to the model structure, so you can apply it to a new model easily.
Use Titans when you want to pursue more extreme performance: it already includes some typical models such as ViT and GPT, but it takes more effort to get started with a new model structure.
### GeminiDPP/ZeRO + Tensor Parallelism
```bash
@ -56,6 +62,11 @@ The `train_gpt_demo.py` provides three distributed plans, you can choose the pla
- Pytorch DDP
- Pytorch ZeRO
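For concreteness, a minimal sketch of launching the demo with the Gemini (`colossalai`) plan is shown below. It drives the `run_gemini.sh` launcher that appears later in this diff purely through environment variables; the values are placeholders and the path is relative to the example folder.

```bash
# select the Gemini plan on 4 GPUs, with TP degree 2 and automatic placement
DISTPLAN=colossalai GPUNUM=4 TPDEGREE=2 PLACEMENT=auto \
BATCH_SIZE=8 MODEL_TYPE=gpt2_medium TRAIN_STEP=10 \
bash ./run_gemini.sh
```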
### Titans (Tensor Parallelism) + ZeRO + Pipeline Parallelism
Titans provides a customized GPT model that uses distributed operators as building blocks.
In [./titans/README.md](./titans/README.md), we provide an example of hybrid parallelism with ZeRO, TP and PP.
You can switch between parallel strategies with a config file.
## Performance

View File

@ -16,14 +16,14 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
BATCH_SIZE = 8
SEQ_LENGTH = 128
HIDDEN_DIM = 3072
BATCH_SIZE = 16
SEQ_LENGTH = 1024
HIDDEN_DIM = 4096
NUM_HEADS = 16
NUM_LAYERS = 1
NUM_LAYERS = 4
VOCAB_SIZE = 50257
NUM_STEPS = 10
FP16 = False
FP16 = True
def get_cpu_mem():
@ -40,7 +40,7 @@ def get_mem_info(prefix=''):
def get_tflops(model_numel, batch_size, seq_len, step_time):
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8
# Randomly Generated Data
@ -66,13 +66,7 @@ def main():
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
}
# Both device mesh initialization and model initialization will be integrated into autoparallelize
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# Enable auto-parallel
gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)
# print solution on rank 0
if gpc.get_global_rank() == 0:

View File

@ -0,0 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1

View File

@ -120,7 +120,7 @@ def run_master(args):
logger.info(f'{rank=} numel in the partition:{numel}')
# build optim
pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
ranks_tflops = {}
for n in range(NUM_STEPS):

View File

@ -1,18 +1,20 @@
for MODEL_TYPE in "gpt2_medium"; do
for BATCH_SIZE in 16; do
for GPUNUM in 1 2 4 8; do
for TPDEGREE in 1 2 4 8; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
echo "****************** Begin ***************************"
echo "* benchmrking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} BS ${BS} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}"
MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./gemini/run_gemini.sh
echo "****************** Finished ***************************"
echo ""
echo ""
for DISTPLAN in "colossalai"; do
for BATCH_SIZE in 16; do
for GPUNUM in 1 2 4 8; do
for TPDEGREE in 1 2 4 8; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
echo "****************** Begin ***************************"
echo "+ benchmrking MODEL ${MODEL_TYPE} DISTPLAN ${DISTPLAN} GPU ${GPUNUM} BS ${BATCH_SIZE} TP ${TPDEGREE} POLICY ${PLACEMENT}"
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
echo "****************** Finished ***************************"
echo ""
echo ""
done
done
done
done

View File

@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True):
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
def gpt2_30b(checkpoint=True):
return GPTLMModel(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint)
def gpt2_40b(checkpoint=True):
return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
def model_builder(model_size: str) -> callable:
if model_size == "gpt2_medium":
return gpt2_medium
@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable:
return gpt2_20b
elif model_size == "gpt2_24b":
return gpt2_24b
elif model_size == "gpt2_30b":
return gpt2_30b
elif model_size == "gpt2_40b":
return gpt2_40b
else:
raise TypeError(f"model_builder {model_size}")

View File

@ -0,0 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1

View File

@ -1,15 +1,15 @@
set -x
# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
export DISTPAN=${DISTPAN:-"colossalai"}
export DISTPLAN=${DISTPLAN:-"colossalai"}
# The following options only valid when DISTPAN="colossalai"
# The following options only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10}
# export PYTHONPATH=$PWD:$PYTHONPATH
mkdir -p gemini_logs
@ -20,5 +20,6 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--batch_size=${BATCH_SIZE} \
--placement=${PLACEMENT} \
--shardinit=${USE_SHARD_INIT} \
--distplan=${DISTPAN} \
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
--distplan=${DISTPLAN} \
--train_step=${TRAIN_STEP} \
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log

View File

@ -0,0 +1,35 @@
set -x
$(cd `dirname $0`;pwd)
export TRAIN_STEP=4
for MODEL_TYPE in "gpt2_medium"; do
for DISTPLAN in "colossalai"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1 2; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
done
done
done
done
done
for DISTPLAN in "zero1" "zero2"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
bash ./run_gemini.sh
done
done
done
done
done

View File

@ -65,6 +65,13 @@ def parse_args():
default="gpt2_medium",
help="model model scale",
)
parser.add_argument(
"--train_step",
type=int,
default=10,
help="training iterations for test",
)
args = parser.parse_args()
return args
@ -180,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
# Gemini + ZeRO DDP
def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
fp16_init_scale = 2**5
gpu_margin_mem_ratio_for_auto = 0
if version.parse(CAI_VERSION) > version.parse("0.1.10"):
model = GeminiDDP(model,
strict_ddp_mode=ddp_flag,
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
hidden_dim=model.config.n_embd,
search_range_mb=64)
search_range_mb=128)
# configure the const policy
if placement_policy == 'const':
model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
@ -236,7 +244,8 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
NUM_STEPS = args.train_step
WARMUP_STEPS = 1
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should be smaller than the total steps"
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
@ -270,14 +279,17 @@ def main():
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
# Tensor Parallelism (TP)
tensor_parallelize(model, tp_pg)
# You should notice that v0.1.10 is not compatible with TP degree > 1
if args.tp_degree > 1:
tensor_parallelize(model, tp_pg)
# build a Gemini model and a highly optimized cpu optimizer
# Gemini + ZeRO DP, Note it must be used after TP
model, optimizer = build_gemini(model, tp_pg, args.placement)
model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
else:
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
model = model_builder(args.model_type)(checkpoint=True).cuda()
if args.distplan.startswith("torch"):
@ -288,12 +300,17 @@ def main():
from torch.distributed.optim import ZeroRedundancyOptimizer
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
elif args.distplan.startswith("zero"):
partition_flag = args.distplan == "zero2"
model = model.half()
partition_flag = (args.distplan == "zero2")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer = LowLevelZeroOptimizer(optimizer,
overlap_communication=True,
partition_grad=partition_flag,
verbose=True)
optimizer = LowLevelZeroOptimizer(
optimizer,
reduce_bucket_size=12 * 1024 * 1024,
overlap_communication=True,
partition_grad=partition_flag,
verbose=True,
)
# model is shared after TP
numel = get_model_size(model)

Some files were not shown because too many files have changed in this diff.