mirror of https://github.com/hpcaitech/ColossalAI
Update metainfo patch branch (#2517)
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fail
* fix a bug in ones_like, don't gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restructure dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [worfklow] added coverage test (#2399)
* [worfklow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoprallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
Support the full evoformer tracer, which is a main module of AlphaFold; previously we supported only a simplified version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, so no other codegen can be set. This adds an argument to control whether ActivationCheckpointCodeGen is set in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
branch: dev/gpt2_metainfo_patch
parent: ce08661eb1
commit: 7a58dc5ad2
@@ -0,0 +1,24 @@
{
    "build": [
        {
            "torch_version": "1.11.0",
            "cuda_image": "hpcaitech/cuda-conda:10.2"
        },
        {
            "torch_version": "1.11.0",
            "cuda_image": "hpcaitech/cuda-conda:11.3"
        },
        {
            "torch_version": "1.12.1",
            "cuda_image": "hpcaitech/cuda-conda:10.2"
        },
        {
            "torch_version": "1.12.1",
            "cuda_image": "hpcaitech/cuda-conda:11.3"
        },
        {
            "torch_version": "1.12.1",
            "cuda_image": "hpcaitech/cuda-conda:11.6"
        }
    ]
}
@@ -0,0 +1,3 @@
1.12.0-11.3.0
1.11.0-11.3.0
1.10.1-11.3.0
@@ -0,0 +1,149 @@
# CI/CD

## Table of Contents

- [CI/CD](#cicd)
  - [Table of Contents](#table-of-contents)
  - [Overview](#overview)
  - [Workflows](#workflows)
    - [Checks on Pull Requests](#checks-on-pull-requests)
    - [Regular Checks](#regular-checks)
    - [Release](#release)
    - [Manual Dispatch](#manual-dispatch)
      - [Release bdist wheel](#release-bdist-wheel)
      - [Dispatch Example Test](#dispatch-example-test)
      - [Compatibility Test](#compatibility-test)
    - [User Friendliness](#user-friendliness)
  - [Configuration](#configuration)
  - [Progress Log](#progress-log)

## Overview

Automation makes our development more efficient, as the machine automatically runs the pre-defined tasks for the contributors.
This saves a lot of manual work and allows the developers to focus fully on features and bug fixes.
In Colossal-AI, we use [GitHub Actions](https://github.com/features/actions) to automate a wide range of workflows to ensure the robustness of the software.
In the sections below, we dive into the details of the different workflows available.

## Workflows

### Checks on Pull Requests

| Workflow Name | File name | Description |
| --- | --- | --- |
| `Build` | `build.yml` | This workflow is triggered when the label `Run build and Test` is assigned to a PR. It will run all the unit tests in the repository with 4 GPUs. |
| `Pre-commit` | `pre_commit.yml` | This workflow runs pre-commit checks for code style consistency. |
| `Report pre-commit failure` | `report_precommit_failure.yml` | This workflow puts up a comment in the PR to explain the pre-commit failure and the remedy. It is executed when `Pre-commit` is completed. |
| `Report test coverage` | `report_test_coverage.yml` | This workflow puts up a comment to report the test coverage results. It is executed when `Build` is completed. |
| `Test example` | `auto_example_check.yml` | An example will be automatically tested if its files are changed in the PR. |

### Regular Checks

| Workflow Name | File name | Description |
| --- | --- | --- |
| `Test example` | `auto_example_check.yml` | This workflow will test all examples every Sunday. |
| `Compatibility Test` | `auto_compatibility_test.yml` | This workflow will check the compatibility of Colossal-AI against PyTorch and CUDA every Sunday. The PyTorch and CUDA versions are specified in `.compatibility`. |
| `Build on 8 GPUs` | `build_gpu_8.yml` | This workflow will run the unit tests every day with 8 GPUs. |
| `Synchronize submodule` | `submodule.yml` | This workflow will check if any git submodule is updated. If so, it will create a PR to update the submodule pointers. |
| `Close inactive issues` | `close_inactive.yml` | This workflow will close issues which have been stale for 14 days. |

### Release

| Workflow Name | File name | Description |
| --- | --- | --- |
| `Draft GitHub Release Post` | `draft_github_release_post.yml` | Compose a GitHub release post draft based on the commit history. Triggered when the change of `version.txt` is merged. |
| `Release to PyPI` | `release_pypi.yml` | Build and release the wheel to PyPI. Triggered when the change of `version.txt` is merged. |
| `Release Nightly to PyPI` | `release_nightly.yml` | Build and release the nightly wheel to PyPI as `colossalai-nightly`. Automatically executed every Sunday. |
| `Release Docker` | `release_docker.yml` | Build and release the Docker image to DockerHub. Triggered when the change of `version.txt` is merged. |
| `Release bdist wheel` | `release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. Manually dispatched. See more details in the next section. |
| `Auto Release bdist wheel` | `auto_release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. Triggered when the change of `version.txt` is merged. Build specifications are stored in `.bdist.json`. |
| `Auto Compatibility Test` | `auto_compatibility_test.yml` | Check Colossal-AI's compatibility against the PyTorch and CUDA versions specified in `.compatibility`. Triggered when `version.txt` is changed in a PR. |

### Manual Dispatch

| Workflow Name | File name | Description |
| --- | --- | --- |
| `Release bdist wheel` | `release_bdist.yml` | Build binary wheels with pre-built PyTorch extensions. |
| `Dispatch Example Test` | `dispatch_example_check.yml` | Manually test a specified example. |
| `Dispatch Compatibility Test` | `dispatch_compatiblity_test.yml` | Test PyTorch and Python compatibility. |

Refer to this [documentation](https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow) on how to manually trigger a workflow.
The details of each workflow are provided below.

#### Release bdist wheel

Parameters:
- `torch version`: torch version to test against; multiple versions are supported but must be separated by a comma. The default value is `all`, which will test all available torch versions listed in this [repository](https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels), which is regularly updated.
- `cuda version`: CUDA versions to test against; multiple versions are supported but must be separated by a comma. The CUDA versions must be present in our [DockerHub repository](https://hub.docker.com/r/hpcaitech/cuda-conda).
- `ref`: the branch or tag name to build the wheel for.

#### Dispatch Example Test

Parameters:
- `example_directory`: the example directory to test. Multiple directories are supported and must be separated by a comma, for example `language/gpt, images/vit`. Inputting only `language` or only `gpt` does not work.

#### Compatibility Test

Parameters:
- `torch version`: torch version to test against; multiple versions are supported but must be separated by a comma. The default value is `all`, which will test all available torch versions listed in this [repository](https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels).
- `cuda version`: CUDA versions to test against; multiple versions are supported but must be separated by a comma. The CUDA versions must be present in our [DockerHub repository](https://hub.docker.com/r/hpcaitech/cuda-conda).

> This workflow only tests the compatibility of the main branch.

### User Friendliness

| Workflow Name | File name | Description |
| --- | --- | --- |
| `issue-translate` | `translate_comment.yml` | This workflow is triggered when a new issue comment is created. The comment will be translated into English if it is not written in English. |

## Configuration

This section lists the files used to configure the workflows.

1. `.compatibility`

The `.compatibility` file tells GitHub Actions which PyTorch and CUDA versions to test against. Each line in the file is in the format `${torch-version}-${cuda-version}`, which is a tag for a Docker image. The tag must therefore be present in the [docker registry](https://hub.docker.com/r/pytorch/conda-cuda) for the test to run. A short parsing sketch is given after this list.
2. `.bdist.json`

This file controls which PyTorch/CUDA-compatible pre-built releases will be built and published. You can add a new entry according to the JSON schema below if there is a new wheel that needs to be built with AOT compilation of PyTorch extensions. A loader sketch is also given after this list.

```json
{
    "build": [
        {
            "torch_version": "",
            "cuda_image": ""
        }
    ]
}
```
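
The following is a minimal illustration (not part of the committed files) of how a `.compatibility` line such as `1.12.0-11.3.0` expands into the container matrix. It mirrors the shell step of the compatibility-test workflow shown earlier; the file path and the `hpcaitech/pytorch-cuda` image prefix are taken from that workflow.

```python
# Sketch only: expand `.compatibility` tags into the container matrix,
# mirroring the shell step of the compatibility-test workflow.
import json

with open(".compatibility") as f:
    tags = [line.strip() for line in f if line.strip()]

matrix = {"container": [f"hpcaitech/pytorch-cuda:{tag}" for tag in tags]}
print(json.dumps(matrix))
# e.g. {"container": ["hpcaitech/pytorch-cuda:1.12.0-11.3.0", ...]}
```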
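Likewise, a minimal sketch (also not part of the committed files) of reading `.bdist.json` and listing the requested wheel builds. The auto-release workflow feeds this JSON directly into its build matrix, so each entry surfaces as `matrix.build.torch_version` and `matrix.build.cuda_image`.

```python
# Sketch only: enumerate the wheel builds requested in `.bdist.json`.
import json

with open(".bdist.json") as f:
    config = json.load(f)

for entry in config["build"]:
    print(f"torch {entry['torch_version']} on image {entry['cuda_image']}")
```
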
## Progress Log

- [x] unit testing
  - [x] test on PR
  - [x] report test coverage
  - [x] regular test
- [x] release
  - [x] official release
  - [x] nightly build
  - [x] binary build
  - [x] docker build
  - [x] draft release post
- [x] pre-commit
  - [x] check on PR
  - [x] report failure
- [x] example check
  - [x] check on PR
  - [x] regular check
  - [x] manual dispatch
- [x] compatibility check
  - [x] manual dispatch
  - [x] auto test when release
- [x] helpers
  - [x] comment translation
  - [x] submodule update
  - [x] close inactive issue
@@ -0,0 +1,74 @@
name: Compatibility Test

on:
  pull_request:
    paths:
      - 'version.txt'
      - '.compatibility'
  # run at 03:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00
  schedule:
    - cron: '0 19 * * 6'

jobs:
  matrix_preparation:
    name: Prepare Container List
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
      - uses: actions/checkout@v3
      - id: set-matrix
        run: |
          IFS=','
          DOCKER_IMAGE=()

          while read tag; do
            DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tag}\"")
          done <.compatibility

          container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" )
          container="[${container}]"
          echo "$container"
          echo "::set-output name=matrix::{\"container\":$(echo "$container")}"

  build:
    name: Test for PyTorch Compatibility
    needs: matrix_preparation
    if: github.repository == 'hpcaitech/ColossalAI'
    runs-on: [self-hosted, gpu]
    strategy:
      fail-fast: false
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
    container:
      image: ${{ matrix.container }}
      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
    timeout-minutes: 120
    steps:
      - name: Install dependencies
        run: |
          pip install -U pip setuptools wheel --user
      - uses: actions/checkout@v2
        with:
          repository: hpcaitech/TensorNVMe
          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
          path: TensorNVMe
      - name: Install tensornvme
        run: |
          cd TensorNVMe
          conda install cmake
          pip install -r requirements.txt
          pip install -v .
      - uses: actions/checkout@v2
        with:
          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
      - name: Install Colossal-AI
        run: |
          pip install -v --no-cache-dir .
          pip install -r requirements/requirements-test.txt
      - name: Unit Testing
        run: |
          PYTHONPATH=$PWD pytest tests
        env:
          DATA: /data/scratch/cifar-10
          NCCL_SHM_DISABLE: 1
          LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
@@ -0,0 +1,143 @@
name: Test Example
on:
  pull_request:
    # any change in the examples folder will trigger check for the corresponding example.
    paths:
      - 'examples/**'
  # run at 00:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00
  schedule:
    - cron: '0 16 * * 6'

jobs:
  # This is for changed example files detect and output a matrix containing all the corresponding directory name.
  detect-changed-example:
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.setup-matrix.outputs.matrix }}
      anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
    name: Detect changed example files
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          ref: ${{ github.event.pull_request.head.sha }}

      - name: Locate base commit
        id: locate-base-sha
        run: |
          curBranch=$(git rev-parse --abbrev-ref HEAD)
          commonCommit=$(git merge-base origin/main $curBranch)
          echo $commonCommit
          echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT

      - name: Get all changed example files
        id: changed-files
        uses: tj-actions/changed-files@v35
        with:
          base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}

      - name: setup matrix
        id: setup-matrix
        run: |
          changedFileName=""
          for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
            changedFileName="${file}:${changedFileName}"
          done
          echo "$changedFileName was changed"
          res=`python .github/workflows/scripts/example_checks/detect_changed_example.py --fileNameList $changedFileName`
          echo "All changed examples are $res"

          if [ "$res" = "[]" ]; then
            echo "anyChanged=false" >> $GITHUB_OUTPUT
            echo "matrix=null" >> $GITHUB_OUTPUT
          else
            dirs=$( IFS=',' ; echo "${res[*]}" )
            echo "anyChanged=true" >> $GITHUB_OUTPUT
            echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
          fi

  # If no file is changed, it will prompt an error and shows the matrix do not have value.
  check-changed-example:
    # Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' &&
      needs.detect-changed-example.outputs.anyChanged == 'true'
    name: Test the changed example
    needs: detect-changed-example
    runs-on: [self-hosted, gpu]
    strategy:
      matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      options: --gpus all --rm -v /data/scratch/examples-data:/data/
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@v3

      - name: Install Colossal-AI
        run: |
          pip install -v .

      - name: Test the example
        run: |
          example_dir=${{ matrix.directory }}
          cd "${PWD}/examples/${example_dir}"
          bash test_ci.sh
        env:
          NCCL_SHM_DISABLE: 1

  # This is for all files' weekly check. Specifically, this job is to find all the directories.
  matrix_preparation:
    if: |
      github.repository == 'hpcaitech/ColossalAI' &&
      github.event_name == 'schedule'
    name: Prepare matrix for weekly check
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.setup-matrix.outputs.matrix }}
    steps:
      - name: 📚 Checkout
        uses: actions/checkout@v3

      - name: setup matrix
        id: setup-matrix
        run: |
          res=`python .github/workflows/scripts/example_checks/check_example_weekly.py`
          all_loc=$( IFS=',' ; echo "${res[*]}" )
          echo "Found the examples: $all_loc"
          echo "matrix={\"directory\":$(echo "$all_loc")}" >> $GITHUB_OUTPUT

  weekly_check:
    if: |
      github.repository == 'hpcaitech/ColossalAI' &&
      github.event_name == 'schedule'
    name: Weekly check all examples
    needs: matrix_preparation
    runs-on: [self-hosted, gpu]
    strategy:
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
    timeout-minutes: 10
    steps:
      - name: 📚 Checkout
        uses: actions/checkout@v3

      - name: Install Colossal-AI
        run: |
          pip install -v .

      - name: Traverse all files
        run: |
          example_dir=${{ matrix.diretory }}
          echo "Testing ${example_dir} now"
          cd "${PWD}/examples/${example_dir}"
          bash test_ci.sh
        env:
          NCCL_SHM_DISABLE: 1
@@ -0,0 +1,70 @@
name: Auto Release bdist wheel

on:
  workflow_dispatch:
  pull_request:
    paths:
      - 'version.txt'
    types:
      - closed

jobs:
  matrix_preparation:
    name: Prepare Container List
    if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
      - uses: actions/checkout@v3
      - id: set-matrix
        run: |
          bdist=$(cat .bdist.json | tr '\n' ' ')
          echo "matrix=${bdist}" >> $GITHUB_OUTPUT

  build:
    name: Release bdist wheels
    needs: matrix_preparation
    runs-on: [self-hosted, gpu]
    strategy:
      fail-fast: false
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
    container:
      image: ${{ matrix.build.cuda_image }}
      options: --gpus all --rm
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      # cub is for cuda 10.2
      - name: Copy scripts
        run: |
          cp -r ./.github/workflows/scripts/* ./

          # link the cache diretories to current path
          ln -s /github/home/conda_pkgs ./conda_pkgs
          ln -s /github/home/pip_wheels ./pip_wheels

          # set the conda package path
          echo "pkgs_dirs:\n - $PWD/conda_pkgs" > ~/.condarc

          # set safe directory
          git config --global --add safe.directory /__w/ColossalAI/ColossalAI

          # get cub package for cuda 10.2
          wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
          unzip 1.8.0.zip
      - name: Build bdist wheel
        run: |
          pip install beautifulsoup4 requests packaging
          python ./build_colossalai_wheel.py --torch_version $TORCH_VERSIONS
        env:
          TORCH_VERSIONS: ${{ matrix.build.torch_version }}
      - name: 🚀 Deploy
        uses: garygrossgarten/github-action-scp@release
        with:
          local: all_dist
          remote: ${{ secrets.PRIVATE_PYPI_DIR }}
          host: ${{ secrets.PRIVATE_PYPI_HOST }}
          username: ${{ secrets.PRIVATE_PYPI_USER }}
          password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
@@ -20,15 +20,26 @@ jobs:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
          ref: ${{ github.event.pull_request.head.sha }}

      - name: Locate base commit
        id: locate-base-sha
        run: |
          curBranch=$(git rev-parse --abbrev-ref HEAD)
          commonCommit=$(git merge-base origin/main $curBranch)
          echo $commonCommit
          echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT

      - name: Find the changed files
        id: find-changed-files
        uses: tj-actions/changed-files@v35
        with:
          since_last_remote_commit: true
          base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
          files: |
            op_builder/**
            colossalai/kernel/**
            setup.py

      - name: List changed files
        run: |
          for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
@@ -75,12 +86,26 @@ jobs:

      - name: Unit Testing
        run: |
          PYTHONPATH=$PWD pytest tests
          PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests
        env:
          DATA: /data/scratch/cifar-10
          NCCL_SHM_DISABLE: 1
          LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

      - name: Collate artifact
        env:
          PR_NUMBER: ${{ github.event.number }}
        run: |
          mkdir report
          echo $PR_NUMBER > ./report/pr_number
          mv coverage.xml ./report

      - name: Upload test coverage artifact
        uses: actions/upload-artifact@v3
        with:
          name: report
          path: report/

      - name: Store Cache
        run: |
          # -p flag is required to preserve the file timestamp to avoid ninja rebuild
@@ -1,119 +0,0 @@
name: Test Example
on:
  pull_request:
    # So only the changes in examples folder will trigger jobs below.
    paths:
      - 'examples/**'
  # run at 00:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00
  schedule:
    - cron: '0 16 * * 6'

jobs:
  # This is for changed example files detect and output a matrix containing all the corresponding directory name.
  detect-changed-example:
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    name: Check out all files
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: Get all changed example files
        id: changed-files
        uses: tj-actions/changed-files@v35
        # Using this can trigger action each time a PR is submitted.
        with:
          since_last_remote_commit: true
      - name: setup matrix
        id: set-matrix
        run: |
          changedFileName=""
          for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
            changedFileName="${file}:${changedFileName}"
          done
          echo "$changedFileName was changed"
          res=`python .github/workflows/scripts/changed_example.py --fileNameList $changedFileName`
          echo "All changed files are $res"
          loc=$( IFS=',' ; echo "${res[*]}" )
          echo "$loc"
          echo "::set-output name=matrix::{\"loc\":$(echo "$loc")}"

  # If no file is changed, it will prompt an error and shows the matrix do not have value.
  check-all-changed-files:
    # Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
    name: Test each changed example files
    needs: detect-changed-example
    runs-on: [self-hosted, gpu]
    strategy:
      matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: Install dependancies
        run: |
          pip install -r ./requirements/requirements.txt
          pip install colossalai
      - name: List all changed example files
        run: |
          res=${{ matrix.loc }}
          cd "${PWD}/examples/${res}"
          bash test_ci.sh

  # This is for all files' weekly check. Specifically, this job is to find all the directories.
  matrix_preparation:
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
    name: Prepare Directory List for All files
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
      - name: 📚 Checkout
        uses: actions/checkout@v3
      - name: setup matrix
        id: set-matrix
        run: |
          res=`python .github/workflows/scripts/weekly_check_example.py`
          all_loc=$( IFS=',' ; echo "${res[*]}" )
          echo "$all_loc"
          echo "::set-output name=matrix::{\"all_loc\":$(echo "$all_loc")}"

  weekly_check:
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
    name: Weekly check all examples
    needs: matrix_preparation
    runs-on: [self-hosted, gpu]
    strategy:
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
    steps:
      - name: 📚 Checkout
        uses: actions/checkout@v3
      - name: Install the requirements
        run: |
          pip install -r ./requirements/requirements.txt
          pip install colossalai
      - name: Traverse all files
        run: |
          dir=${{ matrix.all_loc }}
          echo "${dir} is current directory"
          cd "${PWD}/examples/${dir}"
          bash test_ci.sh
@@ -1,4 +1,4 @@
name: Compatibility Test
name: Dispatch Compatibility Test

on:
  workflow_dispatch:
@@ -8,7 +8,7 @@ on:
        required: true

jobs:
  manual_check_matrix_preparation:
  matrix_preparation:
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
@@ -16,31 +16,24 @@ jobs:
    name: Check the examples user want
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix-1.outputs.matrix }}
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
      - name: 📚 Checkout
        uses: actions/checkout@v3
      - name: Get manual directories
        id: set-matrix-1
      - name: Set up matrix
        id: set-matrix
        env:
          check_dir: ${{ inputs.example_directory }}
        run: |
          all_mannual_check_dir=()
          for cdi in $check_dir
          do
            all_mannual_check_dir+=("\"${cdi}\"")
          done
          man_loc=$( IFS=',' ; echo "${all_mannual_check_dir[*]}" )
          res=`python .github/workflows/scripts/input_check_example.py --fileNameList $man_loc`
          echo "${res} is file existance. 1 for all exist, -1 for at least one file not exist."
          if [ res == -1 ];then
            exit(1)
          res=`python .github/workflows/scripts/example_checks/check_dispatch_inputs.py --fileNameList $check_dir`
          if [ res == "failure" ];then
            exit -1
          fi
          man_loc="[${man_loc}]"
          echo "$man_loc"
          echo "::set-output name=matrix::{\"man_loc\":$(echo "$man_loc")}"
          dirs="[${check_dir}]"
          echo "Testing examples in $dirs"
          echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT

  manual_check:
  test_example:
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
@@ -52,16 +45,19 @@ jobs:
      matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      options: --gpus all --rm -v /data/scratch/examples-data:/data/
    timeout-minutes: 10
    steps:
      - name: 📚 Checkout
        uses: actions/checkout@v3
      - name: Install the requirements
      - name: Install Colossal-AI
        run: |
          pip install -r ./requirements/requirements.txt
          pip install colossalai
      - name: Traverse all files
          pip install -v .
      - name: Test the example
        run: |
          dir=${{ matrix.man_loc }}
          echo "${dir} is current directory"
          dir=${{ matrix.directory }}
          echo "Testing ${dir} now"
          cd "${PWD}/examples/${dir}"
          bash test_ci.sh
        env:
          NCCL_SHM_DISABLE: 1
@@ -8,11 +8,10 @@ on:
    types:
      - closed


jobs:
  release:
    name: Draft Release Post
    if: github.repository == 'hpcaitech/ColossalAI'
    if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
@@ -0,0 +1,71 @@
name: pre-commit

on:
  pull_request:

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
          ref: ${{ github.event.pull_request.head.sha }}

      # the PR branch and the hpcaitech/colossal-ai main branch
      # must share a common commit, we need to locate that commit,
      # which is the commit checked-out or forked when the PR branch is created
      # such that we can look for files changed since that commit
      - name: Locate base commit
        id: locate-base-sha
        run: |
          curBranch=$(git rev-parse --abbrev-ref HEAD)
          commonCommit=$(git merge-base origin/main $curBranch)
          echo $commonCommit
          echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT

      - name: Find the changed files
        id: find-changed-files
        uses: tj-actions/changed-files@v35
        with:
          base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}

      - name: List all changed files
        run: |
          for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
            echo "$file was changed"
          done

      - uses: actions/setup-python@v3

      - name: Cache pre-commit hooks
        uses: actions/cache@v3
        with:
          path: ~/.cache/pre-commit
          key: ${{ runner.os }}-pre-commit-hooks

      - name: Set up pre-commit
        run: |
          pip install pre-commit
          pre-commit install

      - name: Run pre-commit on Changed Files
        id: precommit
        run: |
          for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
            echo "======= running pre-commit on ${file} ======="
            pre-commit run --files $file
          done

      - name: Save PR number
        if: always()
        env:
          PR_NUMBER: ${{ github.event.number }}
        run: |
          mkdir -p ./pr
          echo $PR_NUMBER > ./pr/pr_number
      - uses: actions/upload-artifact@v3
        if: always()
        with:
          name: pr_number
          path: pr/
@@ -2,13 +2,16 @@ name: Publish Docker Image to DockerHub

on:
  workflow_dispatch:
  release:
    types: [published]
  pull_request:
    paths:
      - 'version.txt'
    types:
      - closed

jobs:
  release:
    name: Publish Docker Image to DockerHub
    if: github.repository == 'hpcaitech/ColossalAI'
    if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
    runs-on: [self-hosted, gpu]
    container:
      image: "hpcaitech/docker-in-docker:latest"
@@ -18,23 +21,17 @@ jobs:
        with:
          fetch-depth: 0
      - name: Build Docker
        id: build
        run: |
          version=$(cat version.txt)
          docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t hpcaitech/colossalai:$version ./docker
          tag=hpcaitech/colossalai:$version
          docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker
          echo "tag=${tag}" >> $GITHUB_OUTPUT
      - name: Log in to Docker Hub
        uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
        with:
          images: hpcaitech/colossalai
      - name: Build and push Docker image
        uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
      - name: Push Docker image
        run: |
          docker push ${{ steps.build.outputs.tag }}
@@ -1,73 +1,29 @@
name: Release bdist wheel for Nightly versions
name: Publish Nightly Version to PyPI

on:
  schedule:
    # run at 00:00 of every Sunday
    - cron: '0 0 * * 6'
  workflow_dispatch:
  schedule:
    - cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time

jobs:
  matrix_preparation:
    name: Prepare Container List
  build-n-publish:
    if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI'
    name: Build and publish Python 🐍 distributions 📦 to PyPI
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    timeout-minutes: 20
    steps:
      - id: set-matrix
        run: |
          matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]"
          echo $matrix
          echo "::set-output name=matrix::{\"container\":$(echo $matrix)}"
      - uses: actions/checkout@v2

  build:
    name: Release bdist wheels
    needs: matrix_preparation
    if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
    runs-on: [self-hosted, gpu]
    strategy:
      fail-fast: false
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
    container:
      image: ${{ matrix.container }}
      options: --gpus all --rm
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      # cub is for cuda 10.2
      - name: Copy scripts and checkout
        run: |
          cp -r ./.github/workflows/scripts/* ./
          ln -s /github/home/pip_wheels ./pip_wheels
          wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
          unzip 1.8.0.zip
      - name: Build bdist wheel
        run: |
          pip install beautifulsoup4 requests packaging
          python ./build_colossalai_wheel.py --nightly
      - name: 🚀 Deploy
        uses: garygrossgarten/github-action-scp@release
        with:
          local: all_dist
          remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
          host: ${{ secrets.PRIVATE_PYPI_HOST }}
          username: ${{ secrets.PRIVATE_PYPI_USER }}
          password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
  remove_old_build:
    name: Remove old nightly build
    runs-on: ubuntu-latest
    needs: build
    steps:
      - name: executing remote ssh commands using password
        uses: appleboy/ssh-action@master
        env:
          BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
        with:
          host: ${{ secrets.PRIVATE_PYPI_HOST }}
          username: ${{ secrets.PRIVATE_PYPI_USER }}
          password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
          envs: BUILD_DIR
          script: |
            cd $BUILD_DIR
            find . -type f -mtime +0 -exec rm -f {} +
          script_stop: true
      - uses: actions/setup-python@v2
        with:
          python-version: '3.8.14'

      - run: NIGHTLY=1 python setup.py sdist build

      # publish to PyPI if executed on the main branch
      - name: Publish package to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
          verbose: true
@@ -0,0 +1,67 @@
name: Report Precommit Failure

on:
  workflow_run:
    workflows: [pre-commit]
    types:
      - completed

jobs:
  # comment with a message on how to do pre-commit
  # if the pre-commit check was not passed
  report-precommit-failure:
    runs-on: ubuntu-latest
    if: ${{ github.event.workflow_run.conclusion == 'failure' }}
    steps:
      - name: 'Download artifact'
        uses: actions/github-script@v6
        with:
          script: |
            let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: context.payload.workflow_run.id,
            });
            let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "pr_number"
            })[0];
            let download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            let fs = require('fs');
            fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/pr_number.zip`, Buffer.from(download.data));

      - name: 'Unzip artifact'
        run: unzip pr_number.zip

      - name: 'Comment on PR'
        uses: actions/github-script@v6
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            let fs = require('fs');
            let issue_number = Number(fs.readFileSync('./pr_number'));
            let owner = context.repo.owner;
            let repo = context.repo.repo;
            let run_id = context.payload.workflow_run.id;
            let run_url = `https://github.com/${owner}/${repo}/actions/runs/${run_id}`
            let body = `
              Your pre-commit check failed, follow the steps to run pre-commit on your file for code style consistency.

              1. install pre-commit via "pip install pre-commit"
              2. install pre-commit hooks via "pre-commit install"
              3. run pre-commit on file with format error via "pre-commit run --files path" by replacing "path" with the actual file path
              4. commit and push to your branch

              View your job at ${run_url}.
              Read our "CONTRIBUTING.md" for more reference to the code style.
            `;
            await github.rest.issues.createComment({
              owner: owner,
              repo: repo,
              issue_number: issue_number,
              body: body
            });
@@ -0,0 +1,74 @@
name: Report Test Coverage

on:
  workflow_run:
    workflows: [Build]
    types:
      - completed

jobs:
  report-test-coverage:
    runs-on: ubuntu-latest
    steps:
      - name: 'Download artifact'
        uses: actions/github-script@v6
        with:
          script: |
            let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: context.payload.workflow_run.id,
            });
            let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
              return artifact.name == "report"
            })[0];
            let download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: matchArtifact.id,
              archive_format: 'zip',
            });
            let fs = require('fs');
            fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/report.zip`, Buffer.from(download.data));

      - name: 'Unzip artifact'
        run: |
          unzip report.zip

      - name: Code Coverage Report
        uses: irongut/CodeCoverageSummary@v1.3.0
        with:
          filename: coverage.xml
          badge: true
          format: markdown
          hide_branch_rate: false
          hide_complexity: false
          indicators: true
          output: both
          thresholds: '80 90'

      - name: Make Coverage Report Collapsable
        run: |
          sed -i '2 i <details>' code-coverage-results.md
          sed -i '3 i <summary>Click me to view the complete report</summary>' code-coverage-results.md
          echo "</details>" >> code-coverage-results.md

      - name: 'Comment on PR'
        uses: actions/github-script@v6
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            let fs = require('fs');
            let issue_number = Number(fs.readFileSync('./pr_number'));
            let owner = context.repo.owner;
            let repo = context.repo.repo;
            let run_id = context.payload.workflow_run.id;
            let run_url = `https://github.com/${owner}/${repo}/actions/runs/${run_id}`
            let body = fs.readFileSync('./code-coverage-results.md', {encoding:'utf8', flag:'r'})

            await github.rest.issues.createComment({
              owner: owner,
              repo: repo,
              issue_number: issue_number,
              body: body
            });
@@ -0,0 +1,27 @@
import argparse
import os


def check_inputs(input_list):
    for path in input_list:
        real_path = os.path.join('examples', path)
        if not os.path.exists(real_path):
            return False
    return True


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
    args = parser.parse_args()
    name_list = args.fileNameList.split(",")
    is_correct = check_inputs(name_list)

    if is_correct:
        print('success')
    else:
        print('failure')


if __name__ == '__main__':
    main()
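For reference, a small self-contained illustration (hypothetical directory names, adapted to take an explicit base path) of the check performed by the script above: it succeeds only when every `area/application` entry exists under `examples/`.

```python
# Illustration only: the same existence check as check_inputs(), run against a
# temporary examples/ tree with hypothetical directory names.
import os
import tempfile


def check_inputs(base, input_list):
    return all(os.path.exists(os.path.join(base, "examples", p)) for p in input_list)


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "examples", "images", "vit"))
    print(check_inputs(root, ["images/vit"]))            # True  -> script prints 'success'
    print(check_inputs(root, ["images/vit", "no/dir"]))  # False -> script prints 'failure'
```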
@@ -5,9 +5,9 @@ def show_files(path, all_files):
    # Traverse all the folder/file in current directory
    file_list = os.listdir(path)
    # Determine the element is folder or file. If file, pass it into list, if folder, recurse.
    for file in file_list:
    for file_name in file_list:
        # Get the abs directory using os.path.join() and store into cur_path.
        cur_path = os.path.join(path, file)
        cur_path = os.path.join(path, file_name)
        # Determine whether folder
        if os.path.isdir(cur_path):
            show_files(cur_path, all_files)
@@ -26,9 +26,8 @@ def main():
    for file_loc in contents:
        split_loc = file_loc.split('/')
        # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
        if len(split_loc) - split_loc.index('examples') >= 3:
            tmp_loc = split_loc[(split_loc.index('examples') + 1):(split_loc.index('examples') + 3)]
            re_loc = join(tmp_loc, '/')
        if len(split_loc) >= 4:
            re_loc = '/'.join(split_loc[1:3])
            if re_loc not in all_loc:
                all_loc.append(re_loc)
    print(all_loc)
@@ -3,14 +3,19 @@ import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--fileNameList', type=str)
    parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
    args = parser.parse_args()
    name_list = args.fileNameList.split(":")
    folder_need_check = set()
    for loc in name_list:
        # Find only the sub-folder of 'example' folder
        # Find only the sub-sub-folder of 'example' folder
        # the examples folder structure is like
        # - examples
        #   - area
        #     - application
        #       - file
        if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
            folder_need_check.add(loc.split("/")[1] + "/" + loc.split("/")[2])
            folder_need_check.add('/'.join(loc.split("/")[1:3]))
    # Output the result using print. Then the shell can get the values.
    print(list(folder_need_check))
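To make the mapping concrete, a small illustration (hypothetical file paths, not part of the commit) of the folder extraction above: only paths at least four segments deep under `examples/` contribute an `area/application` directory.

```python
# Illustration only: reproduce the folder extraction from detect_changed_example.py
# on hypothetical changed-file paths.
changed = [
    "examples/images/vit/train.py",      # deep enough -> contributes "images/vit"
    "examples/language/gpt/README.md",   # deep enough -> contributes "language/gpt"
    "examples/requirements.txt",         # too shallow -> ignored
]

folder_need_check = set()
for loc in changed:
    parts = loc.split("/")
    if parts[0] == "examples" and len(parts) >= 4:
        folder_need_check.add("/".join(parts[1:3]))

print(sorted(folder_need_check))  # ['images/vit', 'language/gpt']
```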
@@ -1,23 +0,0 @@
import argparse
import os


def detect_correct(loc_li):
    for loc in loc_li:
        real_loc = 'examples/' + eval(loc)
        if not os.path.exists(real_loc):
            return -1
    return 1


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--fileNameList', type=str)
    args = parser.parse_args()
    name_list = args.fileNameList.split(",")
    result = detect_correct(name_list)
    print(result)


if __name__ == '__main__':
    main()
@@ -0,0 +1,18 @@
name: 'issue-translator'
on:
  issue_comment:
    types: [created]
  issues:
    types: [opened]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: usthe/issues-translate-action@v2.7
        with:
          IS_MODIFY_TITLE: false
          # not require, default false, . Decide whether to modify the issue title
          # if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
          CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑🤝🧑👫🧑🏿🤝🧑🏻👩🏾🤝👨🏿👬🏿
          # not require. Customize the translation robot prefix message.
@@ -151,3 +151,7 @@ colossalai/version.py

# ignore python interface defition file
.pyi

# ignore coverage test file
coverage.lcov
coverage.xml
@ -5,10 +5,10 @@
|
|||
|
||||
Colossal-AI: 一个面向大模型时代的通用深度学习系统
|
||||
|
||||
<h3> <a href="https://arxiv.org/abs/2110.14883"> 论文 </a> |
|
||||
<a href="https://www.colossalai.org/"> 文档 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> 例程 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
|
||||
<h3> <a href="https://arxiv.org/abs/2110.14883"> 论文 </a> |
|
||||
<a href="https://www.colossalai.org/"> 文档 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> 例程 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
|
||||
<a href="https://medium.com/@hpcaitech"> 博客 </a></h3>
|
||||
|
||||
[](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
|
||||
|
@ -35,7 +35,7 @@
|
|||
<li><a href="#为何选择-Colossal-AI">为何选择 Colossal-AI</a> </li>
|
||||
<li><a href="#特点">特点</a> </li>
|
||||
<li>
|
||||
<a href="#并行训练样例展示">并行训练样例展示</a>
|
||||
<a href="#并行训练样例展示">并行训练样例展示</a>
|
||||
<ul>
|
||||
<li><a href="#GPT-3">GPT-3</a></li>
|
||||
<li><a href="#GPT-2">GPT-2</a></li>
|
||||
|
@ -47,14 +47,14 @@
|
|||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#单GPU训练样例展示">单GPU训练样例展示</a>
|
||||
<a href="#单GPU训练样例展示">单GPU训练样例展示</a>
|
||||
<ul>
|
||||
<li><a href="#GPT-2-Single">GPT-2</a></li>
|
||||
<li><a href="#PaLM-Single">PaLM</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#推理-Energon-AI-样例展示">推理 (Energon-AI) 样例展示</a>
|
||||
<a href="#推理-Energon-AI-样例展示">推理 (Energon-AI) 样例展示</a>
|
||||
<ul>
|
||||
<li><a href="#GPT-3-Inference">GPT-3</a></li>
|
||||
<li><a href="#OPT-Serving">1750亿参数OPT在线推理服务</a></li>
|
||||
|
@ -62,7 +62,7 @@
|
|||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI 成功案例</a>
|
||||
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI 成功案例</a>
|
||||
<ul>
|
||||
<li><a href="#AIGC">AIGC: 加速 Stable Diffusion</a></li>
|
||||
<li><a href="#生物医药">生物医药: 加速AlphaFold蛋白质结构预测</a></li>
|
||||
|
@ -131,7 +131,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/(updated)GPT-2.png" width=800>
|
||||
|
||||
- 用相同的硬件训练24倍大的模型
|
||||
- 超3倍的吞吐量
|
||||
- 超3倍的吞吐量
|
||||
|
||||
### BERT
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BERT.png" width=800/>
|
||||
|
@ -145,7 +145,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
|
||||
|
||||
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), 由Meta发布的1750亿语言模型,由于完全公开了预训练参数权重,因此促进了下游任务和应用部署的发展。
|
||||
- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://service.colossalai.org/opt)
|
||||
- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/zh-Hans/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md)
|
||||
|
||||
请访问我们的 [文档](https://www.colossalai.org/) 和 [例程](https://github.com/hpcaitech/ColossalAI-Examples) 以了解详情。
|
||||
|
||||
|
@ -199,7 +199,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_serving.png" width=800/>
|
||||
</p>
|
||||
|
||||
- [OPT推理服务](https://service.colossalai.org/opt): 无需注册,免费体验1750亿参数OPT在线推理服务
|
||||
- [OPT推理服务](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/zh-Hans/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md): 无需注册,免费体验1750亿参数OPT在线推理服务
|
||||
|
||||
<p id="BLOOM-Inference" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
|
||||
|
@ -255,6 +255,28 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
|
||||
## 安装
|
||||
|
||||
### 从PyPI安装
|
||||
|
||||
您可以用下面的命令直接从PyPI上下载并安装Colossal-AI。我们默认不会安装PyTorch扩展包
|
||||
|
||||
```bash
|
||||
pip install colossalai
|
||||
```
|
||||
|
||||
但是,如果你想在安装时就直接构建PyTorch扩展,您可以设置环境变量`CUDA_EXT=1`.
|
||||
|
||||
```bash
|
||||
CUDA_EXT=1 pip install colossalai
|
||||
```
|
||||
|
||||
**否则,PyTorch扩展只会在你实际需要使用它们时在运行时构建。**
|
||||
|
||||
与此同时,我们也每周定时发布Nightly版本,这能让你提前体验到新的feature和bug fix。你可以通过以下命令安装Nightly版本。
|
||||
|
||||
```bash
|
||||
pip install colossalai-nightly
|
||||
```
|
||||
|
||||
### 从官方安装
|
||||
|
||||
您可以访问我们[下载](https://www.colossalai.org/download)页面来安装Colossal-AI,在这个页面上发布的版本都预编译了CUDA扩展。
|
||||
|
@ -274,10 +296,10 @@ pip install -r requirements/requirements.txt
|
|||
pip install .
|
||||
```
|
||||
|
||||
如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装):
|
||||
我们默认在`pip install`时不安装PyTorch扩展,而是在运行时临时编译,如果你想要提前安装这些扩展的话(在使用融合优化器时会用到),可以使用以下命令。
|
||||
|
||||
```shell
|
||||
NO_CUDA_EXT=1 pip install .
|
||||
CUDA_EXT=1 pip install .
|
||||
```
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
@ -327,6 +349,11 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash
|
|||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
|
||||
## CI/CD
|
||||
|
||||
我们使用[GitHub Actions](https://github.com/features/actions)来自动化大部分开发以及部署流程。如果想了解这些工作流是如何运行的,请查看这个[文档](.github/workflows/README.md).
|
||||
|
||||
|
||||
## 引用我们
|
||||
|
||||
```
|
||||
|
@ -338,4 +365,6 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash
|
|||
}
|
||||
```
|
||||
|
||||
Colossal-AI 已被 [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/) 等顶级会议录取为官方教程。
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
|
README.md
|
@ -149,7 +149,7 @@ distributed training and inference in a few lines.
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
|
||||
|
||||
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-billion-parameter AI language model released by Meta. Because its pretrained weights are publicly available, it encourages programmers to build downstream tasks and application deployments on top of it.
|
||||
- 45% speedup when fine-tuning OPT at low cost with just a few lines of code. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt)
|
||||
- 45% speedup when fine-tuning OPT at low cost with just a few lines of code. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/en/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md)
|
||||
|
||||
Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details.
|
||||
|
||||
|
@ -202,7 +202,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_serving.png" width=800/>
|
||||
</p>
|
||||
|
||||
- [OPT Serving](https://service.colossalai.org/opt): Try 175-billion-parameter OPT online services for free, without any registration whatsoever.
|
||||
- [OPT Serving](https://github.com/hpcaitech/ColossalAI-Documentation/blob/main/i18n/en/docusaurus-plugin-content-docs/current/advanced_tutorials/opt_service.md): Try 175-billion-parameter OPT online services for free, without any registration whatsoever.
|
||||
|
||||
<p id="BLOOM-Inference" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
|
||||
|
@ -257,9 +257,32 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
|
|||
|
||||
## Installation
|
||||
|
||||
### Install from PyPI
|
||||
|
||||
You can easily install Colossal-AI with the following command. **By default, we do not build PyTorch extensions during installation.**
|
||||
|
||||
```bash
|
||||
pip install colossalai
|
||||
```
|
||||
|
||||
However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
|
||||
|
||||
```bash
|
||||
CUDA_EXT=1 pip install colossalai
|
||||
```
|
||||
|
||||
**Otherwise, CUDA kernels will be built at runtime when you actually need them.**
|
||||
|
||||
We also release a nightly version to PyPI on a weekly basis. This allows you to access unreleased features and bug fixes in the main branch.
|
||||
You can install it via
|
||||
|
||||
```bash
|
||||
pip install colossalai-nightly
|
||||
```
|
||||
|
||||
### Download From Official Releases
|
||||
|
||||
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built CUDA extensions.
|
||||
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built PyTorch extensions.
|
||||
|
||||
|
||||
### Download From Source
|
||||
|
@ -270,9 +293,6 @@ You can visit the [Download](https://www.colossalai.org/download) page to downlo
|
|||
git clone https://github.com/hpcaitech/ColossalAI.git
|
||||
cd ColossalAI
|
||||
|
||||
# install dependency
|
||||
pip install -r requirements/requirements.txt
|
||||
|
||||
# install colossalai
|
||||
pip install .
|
||||
```
|
||||
|
@ -333,6 +353,11 @@ Thanks so much to all of our amazing contributors!
|
|||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
||||
|
||||
## CI/CD
|
||||
|
||||
We leverage the power of [GitHub Actions](https://github.com/features/actions) to automate our development, release and deployment workflows. Please check out this [documentation](.github/workflows/README.md) on how the automated workflows are operated.
|
||||
|
||||
|
||||
## Cite Us
|
||||
|
||||
```
|
||||
|
@ -344,4 +369,6 @@ Thanks so much to all of our amazing contributors!
|
|||
}
|
||||
```
|
||||
|
||||
Colossal-AI has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), etc.
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
|
|
@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
|||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
if 'activation_checkpoint' in user_node.meta:
|
||||
shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
|
||||
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
|
@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
|
|||
# substitute the origin node with comm_spec_apply_node
|
||||
new_kwargs[str(node)] = comm_spec_apply_node
|
||||
user.kwargs = new_kwargs
|
||||
|
||||
if 'activation_checkpoint' in node.meta:
|
||||
comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def _act_annotataion_pass(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
This pass is used to add the activation checkpoint annotation to the newly inserted nodes.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
|
||||
for node in nodes:
|
||||
if not hasattr(node.meta, 'activation_checkpoint'):
|
||||
from .runtime_preparation_pass import size_processing
|
||||
|
||||
user_act_annotation = -1
|
||||
input_act_annotation = -1
|
||||
for user_node in node.users.keys():
|
||||
if 'activation_checkpoint' in user_node.meta:
|
||||
user_act_annotation = user_node.meta['activation_checkpoint']
|
||||
break
|
||||
for input_node in node._input_nodes.keys():
|
||||
if 'activation_checkpoint' in input_node.meta:
|
||||
input_act_annotation = input_node.meta['activation_checkpoint']
|
||||
break
|
||||
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
|
||||
node.meta['activation_checkpoint'] = user_act_annotation
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
|
|
|
@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||
# It will be used to replace the original node with processing node in slice object
|
||||
node_pairs[node] = size_processing_node
|
||||
size_processing_node._meta_data = node._meta_data
|
||||
if 'activation_checkpoint' in node.meta:
|
||||
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
|
||||
|
||||
user_list = list(node.users.keys())
|
||||
for user in user_list:
|
||||
|
|
|
@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
|
|||
)
|
||||
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
|
|||
into the forward function.
|
||||
'''
|
||||
|
||||
def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
|
||||
def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
|
||||
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
|
||||
'''
|
||||
Args:
|
||||
|
@ -59,18 +60,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
|
|||
pass
|
||||
|
||||
|
||||
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
|
||||
'''
|
||||
This method is used to search the best logical mesh shape for the given world size
|
||||
based on the alpha_beta_dict.
|
||||
|
||||
For example:
|
||||
if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
|
||||
'''
|
||||
# TODO: implement this function
|
||||
return (world_size, 1)
|
||||
|
||||
|
||||
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
|
||||
'''
|
||||
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
|
||||
|
@ -93,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
|
|||
return strategies_constructor
|
||||
|
||||
|
||||
def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
|
||||
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
|
||||
'''
|
||||
This method is used to solve the best solution for the given graph.
|
||||
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
|
||||
|
@ -109,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
|
|||
return solution
|
||||
|
||||
|
||||
def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
|
||||
def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
|
||||
strategies_constructor: StrategiesConstructor):
|
||||
'''
|
||||
This method is used to transform the original graph to the sharded graph.
|
||||
|
@ -127,39 +116,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
|
|||
|
||||
|
||||
def initialize_device_mesh(world_size: int = -1,
|
||||
physical_devices: List[int] = None,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None):
|
||||
logical_mesh_shape: Tuple[int] = None,
|
||||
logical_mesh_id: torch.Tensor = None):
|
||||
'''
|
||||
This method is used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
world_size(optional): the size of device mesh. If the world_size is -1,
|
||||
world_size: the size of device mesh. If the world_size is -1,
|
||||
the world size will be set to the number of GPUs in the current machine.
|
||||
physical_devices: the physical devices used to initialize the device mesh.
|
||||
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
|
||||
for each pair of devices. If the alpha_beta_dict is None, the alpha_beta_dict will be
|
||||
generated by profile_alpha_beta function.
|
||||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
mesh shape.
|
||||
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
|
||||
'''
|
||||
# if world_size is not set, use the world size from torch.distributed
|
||||
if world_size == -1:
|
||||
world_size = dist.get_world_size()
|
||||
device1d = [i for i in range(world_size)]
|
||||
|
||||
if physical_devices is None:
|
||||
physical_devices = [i for i in range(world_size)]
|
||||
physical_mesh = torch.tensor(physical_devices)
|
||||
|
||||
if alpha_beta_dict is None:
|
||||
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
|
||||
alpha_beta_dict = profile_alpha_beta(device1d)
|
||||
ab_profiler = AlphaBetaProfiler(physical_devices)
|
||||
alpha_beta_dict = ab_profiler.alpha_beta_dict
|
||||
else:
|
||||
ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
|
||||
|
||||
if logical_mesh_shape is None:
|
||||
if logical_mesh_shape is None and logical_mesh_id is None:
|
||||
# search for the best logical mesh shape
|
||||
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
|
||||
logical_mesh_id = ab_profiler.search_best_logical_mesh()
|
||||
logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
|
||||
logical_mesh_shape = logical_mesh_id.shape
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
|
||||
|
||||
elif logical_mesh_shape is not None and logical_mesh_id is None:
|
||||
logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
|
||||
physical_mesh = torch.tensor(device1d)
|
||||
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
|
||||
mesh_shape=logical_mesh_shape,
|
||||
logical_mesh_id=logical_mesh_id,
|
||||
mesh_alpha=mesh_alpha,
|
||||
mesh_beta=mesh_beta,
|
||||
init_process_group=True)
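For context, a hypothetical call site for the rewritten initializer is sketched below. It assumes `torch.distributed` has already been initialized with four ranks; the device list and the decision to let the profiler search the logical mesh itself are illustrative choices, not part of this diff.

```python
# Hypothetical usage sketch (assumes torch.distributed is initialized with 4 ranks).
# With no alpha_beta_dict / logical_mesh_shape given, the AlphaBetaProfiler measures
# communication cost and searches the best logical mesh on its own.
device_mesh = initialize_device_mesh(world_size=4, physical_devices=[0, 1, 2, 3])
# The returned DeviceMesh carries the chosen logical mesh plus its alpha/beta values
# and is what initialize_model / autoparallelize consume downstream.
```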
|
||||
|
@ -192,10 +198,10 @@ def initialize_model(model: nn.Module,
|
|||
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
|
||||
return a series of integers, but return the best strategies.
|
||||
'''
|
||||
tracer = ColoTracer()
|
||||
tracer = ColoTracer(trace_act_ckpt=True)
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
strategies_constructor = build_strategy_constructor(graph, device_mesh)
|
||||
if load_solver_solution:
|
||||
|
@ -224,6 +230,7 @@ def autoparallelize(model: nn.Module,
|
|||
data_process_func: callable = None,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None,
|
||||
logical_mesh_id: torch.Tensor = None,
|
||||
save_solver_solution: bool = False,
|
||||
load_solver_solution: bool = False,
|
||||
solver_solution_path: str = None,
|
||||
|
@ -245,6 +252,7 @@ def autoparallelize(model: nn.Module,
|
|||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
|
||||
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
|
||||
to the solution_path.
|
||||
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
|
||||
|
@ -254,7 +262,9 @@ def autoparallelize(model: nn.Module,
|
|||
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
|
||||
the memory budget will be infinity.
|
||||
'''
|
||||
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
|
||||
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
|
||||
logical_mesh_shape=logical_mesh_shape,
|
||||
logical_mesh_id=logical_mesh_id)
|
||||
if meta_args is None:
|
||||
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
|
||||
|
||||
|
@ -263,7 +273,7 @@ def autoparallelize(model: nn.Module,
|
|||
device_mesh,
|
||||
save_solver_solution=save_solver_solution,
|
||||
load_solver_solution=load_solver_solution,
|
||||
solver_solution_path=solver_solution_path,
|
||||
solution_path=solver_solution_path,
|
||||
return_solution=return_solution,
|
||||
memory_budget=memory_budget)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from .layer_norm_handler import LayerNormModuleHandler
|
|||
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
|
||||
from .matmul_handler import MatMulHandler
|
||||
from .normal_pooling_handler import NormPoolingHandler
|
||||
from .option import ShardOption
|
||||
from .output_handler import OutputHandler
|
||||
from .placeholder_handler import PlaceholderHandler
|
||||
from .registry import operator_registry
|
||||
|
@ -27,5 +28,5 @@ __all__ = [
|
|||
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
|
||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
|
||||
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
|
||||
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
|
||||
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
|
||||
]
|
||||
|
|
|
@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
|||
return OperationDataType.ARG
|
||||
|
||||
def _get_arg_value(idx):
|
||||
non_tensor = False
|
||||
if isinstance(self.node.args[idx], Node):
|
||||
meta_data = self.node.args[idx]._meta_data
|
||||
# The meta_data of node type argument could also possibly be a non-tensor object.
|
||||
if not isinstance(meta_data, torch.Tensor):
|
||||
assert isinstance(meta_data, (int, float))
|
||||
meta_data = torch.Tensor([meta_data]).to('meta')
|
||||
non_tensor = True
|
||||
|
||||
else:
|
||||
# this is in fact real data, such as the int 1,
|
||||
# but we can deem it as meta data
|
||||
# as it won't affect the strategy generation
|
||||
assert isinstance(self.node.args[idx], (int, float))
|
||||
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
|
||||
return meta_data
|
||||
non_tensor = True
|
||||
|
||||
input_meta_data = _get_arg_value(0)
|
||||
other_meta_data = _get_arg_value(1)
|
||||
return meta_data, non_tensor
|
||||
|
||||
input_meta_data, non_tensor_input = _get_arg_value(0)
|
||||
other_meta_data, non_tensor_other = _get_arg_value(1)
|
||||
output_meta_data = self.node._meta_data
|
||||
|
||||
# we need to record op_data with non-tensor data in this list,
|
||||
# and filter the non-tensor op_data in post_process.
|
||||
self.non_tensor_list = []
|
||||
# assert False
|
||||
input_op_data = OperationData(name=str(self.node.args[0]),
|
||||
type=_get_op_data_type(input_meta_data),
|
||||
data=input_meta_data,
|
||||
|
@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
|||
type=OperationDataType.OUTPUT,
|
||||
data=output_meta_data,
|
||||
logical_shape=bcast_shape)
|
||||
if non_tensor_input:
|
||||
self.non_tensor_list.append(input_op_data)
|
||||
if non_tensor_other:
|
||||
self.non_tensor_list.append(other_op_data)
|
||||
|
||||
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
|
||||
return mapping
|
||||
|
@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
|||
op_data_mapping = self.get_operation_data_mapping()
|
||||
|
||||
for op_name, op_data in op_data_mapping.items():
|
||||
if not isinstance(op_data.data, torch.Tensor):
|
||||
if op_data in self.non_tensor_list:
|
||||
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
|
||||
strategy.sharding_specs.pop(op_data)
|
||||
|
||||
else:
|
||||
# convert the logical sharding spec to physical sharding spec if broadcast
|
||||
# e.g. torch.rand(4, 4) + torch.rand(4)
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
|
@ -35,12 +36,14 @@ class NodeHandler(ABC):
|
|||
node: Node,
|
||||
device_mesh: DeviceMesh,
|
||||
strategies_vector: StrategiesVector,
|
||||
shard_option: ShardOption = ShardOption.STANDARD,
|
||||
) -> None:
|
||||
self.node = node
|
||||
self.predecessor_node = list(node._input_nodes.keys())
|
||||
self.successor_node = list(node.users.keys())
|
||||
self.device_mesh = device_mesh
|
||||
self.strategies_vector = strategies_vector
|
||||
self.shard_option = shard_option
|
||||
|
||||
def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
|
||||
"""
|
||||
|
@ -181,6 +184,21 @@ class NodeHandler(ABC):
|
|||
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
|
||||
check_sharding_spec_validity(sharding_spec, op_data.data)
|
||||
|
||||
remove_strategy_list = []
|
||||
for strategy in self.strategies_vector:
|
||||
shard_level = 0
|
||||
for op_data, sharding_spec in strategy.sharding_specs.items():
|
||||
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
|
||||
for dim, shard_axis in sharding_spec.dim_partition_dict.items():
|
||||
shard_level += len(shard_axis)
|
||||
if self.shard_option == ShardOption.SHARD and shard_level == 0:
|
||||
remove_strategy_list.append(strategy)
|
||||
if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
|
||||
remove_strategy_list.append(strategy)
|
||||
|
||||
for strategy in remove_strategy_list:
|
||||
self.strategies_vector.remove(strategy)
|
||||
|
||||
return self.strategies_vector
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
from enum import Enum
|
||||
|
||||
__all__ = ['ShardOption']
|
||||
|
||||
|
||||
class ShardOption(Enum):
|
||||
"""
|
||||
This enum class defines the shard level required for node strategies.
|
||||
|
||||
Notes:
|
||||
STANDARD: We do not add any extra shard requirements.
|
||||
SHARD: We require the node to be sharded along at least one device mesh axis.
|
||||
FULL_SHARD: We require the node to be sharded along all device mesh axes.
|
||||
"""
|
||||
STANDARD = 0
|
||||
SHARD = 1
|
||||
FULL_SHARD = 2
|
|
@ -0,0 +1,523 @@
|
|||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
|
||||
if CODEGEN_AVAILABLE:
|
||||
from torch.fx.graph import (
|
||||
CodeGen,
|
||||
PythonCode,
|
||||
_custom_builtins,
|
||||
_CustomBuiltin,
|
||||
_format_target,
|
||||
_is_from_torch,
|
||||
_Namespace,
|
||||
_origin_type_map,
|
||||
inplace_methods,
|
||||
magic_methods,
|
||||
)
|
||||
|
||||
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
|
||||
|
||||
from .search_chunk import SearchChunk
|
||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
|
||||
|
||||
|
||||
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
|
||||
"""
|
||||
Generate a chunk slice string, e.g. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :]
|
||||
|
||||
Args:
|
||||
chunk_dim (int)
|
||||
chunk_indice_name (str): chunk indice name
|
||||
shape (List): node shape
|
||||
|
||||
Returns:
|
||||
new_shape (str): return slice
|
||||
"""
|
||||
new_shape = "["
|
||||
for idx, _ in enumerate(shape):
|
||||
if idx == chunk_dim:
|
||||
new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
|
||||
else:
|
||||
new_shape += ":"
|
||||
new_shape += ", "
|
||||
new_shape = new_shape[:-2] + "]"
|
||||
return new_shape
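A quick usage check of the helper above; the import path is an assumption, since the function lives in the new codegen file added by this diff.

```python
# The module path below is an assumption; adjust it to wherever this codegen file lives.
from colossalai.autochunk.autochunk_codegen import _gen_chunk_slice_dim

print(_gen_chunk_slice_dim(chunk_dim=2, chunk_indice_name="chunk_idx", shape=[1, 8, 100, 64]))
# expected output: [:, :, chunk_idx:chunk_idx + chunk_size, :]
```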
|
||||
|
||||
|
||||
def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
|
||||
"""
|
||||
Generate chunk loop start
|
||||
|
||||
eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
|
||||
chunk_size = 32
|
||||
for chunk_idx in range(0, 100, 32):
|
||||
......
|
||||
|
||||
Args:
|
||||
chunk_input (List[Node]): chunk input node
|
||||
chunk_output (Node): chunk output node
|
||||
chunk_ouput_dim (int): chunk output node chunk dim
|
||||
chunk_size (int): chunk size. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
context (str): generated str
|
||||
"""
|
||||
input_node = chunk_input[0]
|
||||
out_shape = get_node_shape(chunk_output)
|
||||
out_str = str(list(out_shape))
|
||||
context = (
|
||||
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
|
||||
(out_str, input_node.name, input_node.name, chunk_size))
|
||||
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
|
||||
return context
|
||||
|
||||
|
||||
def _gen_loop_end(
|
||||
chunk_inputs: List[Node],
|
||||
chunk_non_compute_inputs: List[Node],
|
||||
chunk_outputs: Node,
|
||||
chunk_outputs_dim: int,
|
||||
node_list: List[Node],
|
||||
) -> str:
|
||||
"""
|
||||
Generate chunk loop end
|
||||
|
||||
eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
|
||||
output_node = chunk_result; xx = None; xx = None
|
||||
|
||||
Args:
|
||||
chunk_inputs (List[Node]): chunk input node
|
||||
chunk_non_compute_inputs (List[Node]): input node without chunk
|
||||
chunk_outputs (Node): chunk output node
|
||||
chunk_outputs_dim (int): chunk output node chunk dim
|
||||
node_list (List)
|
||||
|
||||
Returns:
|
||||
context (str): generated str
|
||||
"""
|
||||
chunk_outputs_name = chunk_outputs.name
|
||||
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
|
||||
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
|
||||
context = " chunk_result%s = %s; %s = None\n" % (
|
||||
chunk_slice,
|
||||
chunk_outputs_name,
|
||||
chunk_outputs_name,
|
||||
)
|
||||
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
|
||||
|
||||
# determine if this is the last use of the chunk input
|
||||
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
|
||||
if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
|
||||
context += "; %s = None" % chunk_input.name
|
||||
|
||||
context += "\n"
|
||||
return context
|
||||
|
||||
|
||||
def _replace_name(context: str, name_from: str, name_to: str) -> str:
|
||||
"""
|
||||
replace node name
|
||||
"""
|
||||
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
|
||||
for p in patterns:
|
||||
source = p[0] + name_from + p[1]
|
||||
target = p[0] + name_to + p[1]
|
||||
if source in context:
|
||||
context = context.replace(source, target)
|
||||
break
|
||||
return context
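To make the delimiter-based matching concrete: `_replace_name` only rewrites an occurrence of the node name that is wrapped by one of the listed delimiter pairs (checked in order, first match wins), which is how `_replace_input_node` appends a chunk slice to a generated call's argument. The import path below is an assumption.

```python
from colossalai.autochunk.autochunk_codegen import _replace_name    # path is an assumption

line = "add_1 = torch.add(x, y)"
print(_replace_name(line, "x", "x[chunk_idx:chunk_idx + chunk_size]"))
# expected output: add_1 = torch.add(x[chunk_idx:chunk_idx + chunk_size], y)
```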
|
||||
|
||||
|
||||
def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
|
||||
"""
|
||||
replace reshape sizes; some may have changed due to chunking
|
||||
"""
|
||||
if node_name not in reshape_size_dict:
|
||||
return context
|
||||
context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
|
||||
return context
|
||||
|
||||
|
||||
def _replace_ones_like(
|
||||
search_chunk: SearchChunk,
|
||||
chunk_infos: List[Dict],
|
||||
region_idx: int,
|
||||
node_idx: int,
|
||||
node: Node,
|
||||
body: List[str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
add a chunk slice for new-tensor ops such as ones_like
|
||||
"""
|
||||
if "ones_like" in node.name:
|
||||
meta_node = search_chunk.trace_indice.node_list[node_idx]
|
||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||
source_node = meta_node.args[0].args[0]
|
||||
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
||||
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
|
||||
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
|
||||
return body
|
||||
|
||||
|
||||
def _replace_input_node(
|
||||
chunk_inputs: List[Node],
|
||||
region_idx: int,
|
||||
chunk_inputs_dim: Dict,
|
||||
node_idx: int,
|
||||
body: List[str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
add chunk slice for input nodes
|
||||
"""
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
|
||||
body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
|
||||
return body
|
||||
|
||||
|
||||
def emit_code_with_chunk(
|
||||
body: List[str],
|
||||
nodes: Iterable[Node],
|
||||
emit_node_func,
|
||||
delete_unused_value_func,
|
||||
search_chunk: SearchChunk,
|
||||
chunk_infos: List,
|
||||
):
|
||||
"""
|
||||
Emit code with chunk according to chunk_infos.
|
||||
|
||||
It will generate a for loop in chunk regions, and
|
||||
replace inputs and outputs of regions with chunked variables.
|
||||
|
||||
Args:
|
||||
body: forward code
|
||||
nodes: graph.nodes
|
||||
emit_node_func: function to emit node
|
||||
delete_unused_value_func: function to remove the unused value
|
||||
search_chunk: the class to search all chunks
|
||||
chunk_infos: store all information about all chunks.
|
||||
"""
|
||||
node_list = list(nodes)
|
||||
|
||||
# chunk region
|
||||
chunk_starts = [i["region"][0] for i in chunk_infos]
|
||||
chunk_ends = [i["region"][1] for i in chunk_infos]
|
||||
|
||||
# chunk inputs
|
||||
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
|
||||
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
|
||||
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
|
||||
|
||||
# chunk outputs
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
||||
|
||||
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
|
||||
node_idx = 0
|
||||
region_idx = 0
|
||||
within_chunk_region = False
|
||||
|
||||
while node_idx < len(node_list):
|
||||
node = node_list[node_idx]
|
||||
|
||||
# if this is a chunk start, generate the for-loop header
|
||||
if node_idx in chunk_starts:
|
||||
within_chunk_region = True
|
||||
region_idx = chunk_starts.index(node_idx)
|
||||
body.append(
|
||||
_gen_loop_start(
|
||||
chunk_inputs[region_idx],
|
||||
chunk_outputs[region_idx],
|
||||
chunk_outputs_dim[region_idx],
|
||||
chunk_infos[region_idx]["chunk_size"],
|
||||
))
|
||||
|
||||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
|
||||
# ones like
|
||||
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
|
||||
# reassgin reshape size
|
||||
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
|
||||
body[-1] = " " + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
if node_idx not in chunk_inputs:
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
|
||||
# generate chunk region end
|
||||
if node_idx in chunk_ends:
|
||||
body.append(
|
||||
_gen_loop_end(
|
||||
chunk_inputs[region_idx],
|
||||
chunk_inputs_non_chunk[region_idx],
|
||||
chunk_outputs[region_idx],
|
||||
chunk_outputs_dim[region_idx],
|
||||
node_list,
|
||||
))
|
||||
within_chunk_region = False
|
||||
|
||||
node_idx += 1
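Putting `_gen_loop_start`, the per-node slice rewrites and `_gen_loop_end` together, the emitted forward body for a single chunk region roughly takes the shape below. This is a hand-written, runnable mock-up: the `linear`/`relu` nodes, shapes and chunk size are invented for illustration, while the real values come from `chunk_infos`.

```python
import torch
import torch.nn as nn

# Runnable mock-up of the code shape emitted for one chunk region; names, shapes and
# the chunk size are invented, the structure mirrors _gen_loop_start/_gen_loop_end.
linear = nn.Linear(16, 64)
hidden = torch.randn(100, 16)

chunk_result = torch.empty([100, 64], dtype=hidden.dtype, device=hidden.device); chunk_size = 32
for chunk_idx in range(0, 100, chunk_size):
    linear_1 = linear(hidden[chunk_idx:chunk_idx + chunk_size, :])
    relu_1 = torch.relu(linear_1); linear_1 = None
    chunk_result[chunk_idx:chunk_idx + chunk_size, :] = relu_1; relu_1 = None
relu_1 = chunk_result; chunk_result = None; chunk_size = None; hidden = None
print(relu_1.shape)    # torch.Size([100, 64])
```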
|
||||
|
||||
|
||||
if CODEGEN_AVAILABLE:
|
||||
|
||||
class AutoChunkCodeGen(CodeGen):
|
||||
|
||||
def __init__(self,
|
||||
meta_graph,
|
||||
max_memory: int = None,
|
||||
print_mem: bool = False,
|
||||
print_progress: bool = False) -> None:
|
||||
super().__init__()
|
||||
# find the chunk regions
|
||||
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
|
||||
self.chunk_infos = self.search_chunk.search_region()
|
||||
if print_progress:
|
||||
get_logger().info("AutoChunk start codegen")
|
||||
|
||||
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
|
||||
free_vars: List[str] = []
|
||||
body: List[str] = []
|
||||
globals_: Dict[str, Any] = {}
|
||||
wrapped_fns: Dict[str, None] = {}
|
||||
|
||||
# Wrap string in list to pass by reference
|
||||
maybe_return_annotation: List[str] = [""]
|
||||
|
||||
def add_global(name_hint: str, obj: Any):
|
||||
"""Add an obj to be tracked as a global.
|
||||
|
||||
We call this for names that reference objects external to the
|
||||
Graph, like functions or types.
|
||||
|
||||
Returns: the global name that should be used to reference 'obj' in generated source.
|
||||
"""
|
||||
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
|
||||
# HACK: workaround for how torch custom ops are registered. We
|
||||
# can't import them like normal modules so they must retain their
|
||||
# fully qualified name.
|
||||
return _get_qualified_name(obj)
|
||||
|
||||
# normalize the name hint to get a proper identifier
|
||||
global_name = namespace.create_name(name_hint, obj)
|
||||
|
||||
if global_name in globals_:
|
||||
assert globals_[global_name] is obj
|
||||
return global_name
|
||||
globals_[global_name] = obj
|
||||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
add_global(name, obj)
|
||||
|
||||
def type_repr(o: Any):
|
||||
if o == ():
|
||||
# Empty tuple is used for empty tuple type annotation Tuple[()]
|
||||
return "()"
|
||||
|
||||
typename = _type_repr(o)
|
||||
|
||||
if hasattr(o, "__origin__"):
|
||||
# This is a generic type, e.g. typing.List[torch.Tensor]
|
||||
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
|
||||
origin_typename = add_global(_type_repr(origin_type), origin_type)
|
||||
|
||||
if hasattr(o, "__args__"):
|
||||
# Assign global names for each of the inner type variables.
|
||||
args = [type_repr(arg) for arg in o.__args__]
|
||||
|
||||
if len(args) == 0:
|
||||
# Bare type, such as `typing.Tuple` with no subscript
|
||||
# This code-path used in Python < 3.9
|
||||
return origin_typename
|
||||
|
||||
return f'{origin_typename}[{",".join(args)}]'
|
||||
else:
|
||||
# Bare type, such as `typing.Tuple` with no subscript
|
||||
# This code-path used in Python 3.9+
|
||||
return origin_typename
|
||||
|
||||
# Common case: this is a regular module name like 'foo.bar.baz'
|
||||
return add_global(typename, o)
|
||||
|
||||
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
|
||||
|
||||
def _get_repr(arg):
|
||||
# Handle NamedTuples (if it has `_fields`) via add_global.
|
||||
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
|
||||
qualified_name = _get_qualified_name(type(arg))
|
||||
global_name = add_global(qualified_name, type(arg))
|
||||
return f"{global_name}{repr(tuple(arg))}"
|
||||
return repr(arg)
|
||||
|
||||
args_s = ", ".join(_get_repr(a) for a in args)
|
||||
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
|
||||
if args_s and kwargs_s:
|
||||
return f"{args_s}, {kwargs_s}"
|
||||
return args_s or kwargs_s
|
||||
|
||||
# Run through reverse nodes and record the first instance of a use
|
||||
# of a given node. This represents the *last* use of the node in the
|
||||
# execution order of the program, which we will use to free unused
|
||||
# values
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
if n not in node_to_last_use:
|
||||
node_to_last_use[n] = user
|
||||
user_to_last_uses.setdefault(user, []).append(n)
|
||||
|
||||
for node in reversed(nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
delete_free_var_from_last_use(user_to_last_uses)
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body, to_keep=[]):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
of the code is optimal.
|
||||
"""
|
||||
if user.op == "placeholder":
|
||||
return
|
||||
if user.op == "output":
|
||||
body.append("\n")
|
||||
return
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
|
||||
if len(nodes_to_delete):
|
||||
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
|
||||
body.append(f"; {to_delete_str}\n")
|
||||
else:
|
||||
body.append("\n")
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
|
||||
if node.op == "placeholder":
|
||||
assert isinstance(node.target, str)
|
||||
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
|
||||
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
|
||||
raw_name = node.target.replace("*", "")
|
||||
if raw_name != repr(node):
|
||||
body.append(f"{repr(node)} = {raw_name}\n")
|
||||
return
|
||||
elif node.op == "call_method":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
|
||||
f"({_format_args(node.args[1:], node.kwargs)})")
|
||||
return
|
||||
elif node.op == "call_function":
|
||||
assert callable(node.target)
|
||||
# pretty print operators
|
||||
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
|
||||
assert isinstance(node.args, tuple)
|
||||
body.append(f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
|
||||
return
|
||||
|
||||
# pretty print inplace operators; required for jit.script to work properly
|
||||
# not currently supported in normal FX graphs, but generated by torchdynamo
|
||||
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
|
||||
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
|
||||
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
|
||||
return
|
||||
|
||||
qualified_name = _get_qualified_name(node.target)
|
||||
global_name = add_global(qualified_name, node.target)
|
||||
# special case for getattr: node.args could be 2-argument or 3-argument
|
||||
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
|
||||
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
|
||||
and node.args[1].isidentifier() and len(node.args) == 2):
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
|
||||
return
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
|
||||
if node.meta.get("is_wrapped", False):
|
||||
wrapped_fns.setdefault(global_name)
|
||||
return
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
|
||||
return
|
||||
elif node.op == "get_attr":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
|
||||
return
|
||||
elif node.op == "output":
|
||||
if node.type is not None:
|
||||
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
|
||||
body.append(self.generate_output(node.args[0]))
|
||||
return
|
||||
raise NotImplementedError(f"node: {node.op} {node.target}")
|
||||
|
||||
# Modified for activation checkpointing
|
||||
ckpt_func = []
|
||||
|
||||
# emit the forward body, inserting chunk loops for the searched chunk
|
||||
# regions found by SearchChunk (see emit_code_with_chunk above)
|
||||
emit_code_with_chunk(
|
||||
body,
|
||||
nodes,
|
||||
emit_node,
|
||||
delete_unused_values,
|
||||
self.search_chunk,
|
||||
self.chunk_infos,
|
||||
)
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
# have been emitted. To continue to have valid Python code, emit a
|
||||
# single pass statement
|
||||
body.append("pass\n")
|
||||
|
||||
if len(wrapped_fns) > 0:
|
||||
wrap_name = add_global("wrap", torch.fx.wrap)
|
||||
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
|
||||
else:
|
||||
wrap_stmts = ""
|
||||
|
||||
if self._body_transformer:
|
||||
body = self._body_transformer(body)
|
||||
|
||||
for name, value in self.additional_globals():
|
||||
add_global(name, value)
|
||||
|
||||
# as we need colossalai.utils.checkpoint, we need to import colossalai
|
||||
# in forward function
|
||||
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
|
||||
prologue = "".join(ckpt_func) + prologue
|
||||
prologue = prologue
|
||||
|
||||
code = "".join(body)
|
||||
code = "\n".join(" " + line for line in code.split("\n"))
|
||||
fn_code = f"""
|
||||
{wrap_stmts}
|
||||
|
||||
{prologue}
|
||||
{code}"""
|
||||
# print(fn_code)
|
||||
return PythonCode(fn_code, globals_)
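For orientation, a hypothetical end-to-end wiring of this codegen is sketched below, reusing `ColoTracer`/`ColoGraphModule` as they appear elsewhere in this diff. The import path of `AutoChunkCodeGen`, the meta-info propagation step, the `max_memory` unit (MB) and the availability of a custom-CodeGen hook on the installed torch version are all assumptions.

```python
import torch
import torch.nn as nn

from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen    # path is an assumption


class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 64)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel()
meta_args = {"x": torch.rand(32, 64, device="meta")}

graph = ColoTracer().trace(root=model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
# NOTE: node.meta["tensor_meta"] / meta["fwd_out"] must be populated (e.g. by a
# meta-info propagation pass) before the chunk search can estimate memory.

codegen = AutoChunkCodeGen(gm, max_memory=2 * 1024, print_mem=True)    # max_memory assumed to be in MB
graph.set_codegen(codegen)    # requires a torch.fx version that supports custom CodeGen
gm.recompile()                # the regenerated forward now contains the chunk loops
```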
|
|
@ -0,0 +1,323 @@
|
|||
import copy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node, map_arg
|
||||
|
||||
from colossalai.fx.profiler import activation_size, parameter_size
|
||||
|
||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
|
||||
|
||||
|
||||
class EstimateMemory(object):
|
||||
"""
|
||||
Estimate memory with chunk
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_meta_node_size(self, x):
|
||||
x = x.meta["tensor_meta"]
|
||||
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
|
||||
return x
|
||||
|
||||
def _get_output_node(self, n):
|
||||
out_size = activation_size(n.meta["fwd_out"])
|
||||
out_node = [n.name] if out_size > 0 else []
|
||||
return out_size, out_node
|
||||
|
||||
def _get_output_node_size(self, n):
|
||||
return self._get_output_node(n)[0]
|
||||
|
||||
def _add_active_node(self, n, active_list):
|
||||
new_active = self._get_output_node(n)[1]
|
||||
if n.op == "placeholder" and get_node_shape(n) is not None:
|
||||
new_active.append(n.name)
|
||||
for i in new_active:
|
||||
if i not in active_list and get_node_shape(n) is not None:
|
||||
active_list.append(i)
|
||||
|
||||
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
|
||||
delete_size = 0
|
||||
delete_node = []
|
||||
if user.op not in ("output",):
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
if len(user.users) == 0:
|
||||
nodes_to_delete.append(user)
|
||||
if to_keep is not None:
|
||||
keep_list = []
|
||||
for n in nodes_to_delete:
|
||||
if n.name in to_keep:
|
||||
keep_list.append(n)
|
||||
for n in keep_list:
|
||||
if n in nodes_to_delete:
|
||||
nodes_to_delete.remove(n)
|
||||
if len(nodes_to_delete):
|
||||
out_node = [self._get_output_node(i) for i in nodes_to_delete]
|
||||
delete_size = sum([i[0] for i in out_node])
|
||||
for i in range(len(out_node)):
|
||||
if out_node[i][0] > 0:
|
||||
delete_node.append(out_node[i][1][0])
|
||||
elif nodes_to_delete[i].op == "placeholder":
|
||||
delete_node.append(nodes_to_delete[i].name)
|
||||
# elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
|
||||
# delete_node.append(nodes_to_delete[i].name)
|
||||
return delete_size, delete_node
|
||||
|
||||
def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
|
||||
return self._get_delete_node(user, user_to_last_uses, to_keep)[0]
|
||||
|
||||
def _remove_deactive_node(self, user, user_to_last_uses, active_list):
|
||||
delete_node = self._get_delete_node(user, user_to_last_uses)[1]
|
||||
for i in delete_node:
|
||||
if i in active_list:
|
||||
active_list.remove(i)
|
||||
|
||||
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
|
||||
nodes_to_delete = []
|
||||
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
|
||||
chunk_input_users = chunk_input.users.keys()
|
||||
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
|
||||
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
|
||||
if chunk_input not in nodes_to_delete:
|
||||
nodes_to_delete.append(chunk_input)
|
||||
out_node = [self._get_output_node(i) for i in nodes_to_delete]
|
||||
delete_size = sum([i[0] for i in out_node])
|
||||
return delete_size
|
||||
|
||||
def _get_last_usr(self, nodes):
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
if n not in node_to_last_use:
|
||||
node_to_last_use[n] = user
|
||||
user_to_last_uses.setdefault(user, []).append(n)
|
||||
|
||||
for node in reversed(nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
return user_to_last_uses
|
||||
|
||||
def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
|
||||
mem = 0
|
||||
not_contiguous_ops = ["permute"]
|
||||
inherit_contiguous_ops = ["transpose", "view"]
|
||||
|
||||
if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
|
||||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
# matmul won't change the original tensor, but creates a temporary copy
|
||||
mem += self._get_output_node_size(n)
|
||||
elif node.op == "call_module":
|
||||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
# a module call will just make the original tensor contiguous
|
||||
if delete:
|
||||
not_contiguous_list.remove(n)
|
||||
elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
|
||||
if node not in not_contiguous_list:
|
||||
not_contiguous_list.append(node)
|
||||
return mem
|
||||
|
||||
def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
|
||||
if node not in chunk_node_dim:
|
||||
return 1.0
|
||||
node_shape = get_node_shape(node)
|
||||
chunk_dim = chunk_node_dim[node]["chunk_dim"]
|
||||
if chunk_dim is None:
|
||||
return 1.0
|
||||
else:
|
||||
return float(chunk_size) / node_shape[chunk_dim]
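As a worked example of the ratio above (shapes and chunk size are invented): a node of shape (1, 512, 64) chunked along dimension 1 with chunk_size 32 only keeps 32/512 of its activation alive at a time, so the estimator scales its contribution accordingly.

```python
# Worked example of the chunk ratio used by the estimator; values are illustrative.
node_shape = (1, 512, 64)
chunk_dim, chunk_size = 1, 32
chunk_ratio = chunk_size / node_shape[chunk_dim]       # 0.0625
full_bytes = 1 * 512 * 64 * 4                          # fp32 activation size in bytes
print(full_bytes * chunk_ratio / (1024 ** 2))          # ~0.0078 MB counted per chunk step
```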
|
||||
|
||||
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
|
||||
# if any(j in user.name for j in ['transpose', 'permute', 'view']):
|
||||
# return 0
|
||||
if user.op in ("placeholder", "output"):
|
||||
return 0
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
if len(user.users) == 0:
|
||||
nodes_to_delete.append(user)
|
||||
delete_size = 0
|
||||
for n in nodes_to_delete:
|
||||
if n.name in chunk_inputs_names:
|
||||
continue
|
||||
delete_size += self._get_output_node_size(n) * chunk_ratio
|
||||
return delete_size
|
||||
|
||||
def _print_mem_log(self, log, nodes, title=None):
|
||||
if title:
|
||||
print(title)
|
||||
for idx, (l, n) in enumerate(zip(log, nodes)):
|
||||
print("%s:%.2f \t" % (n.name, l), end="")
|
||||
if (idx + 1) % 3 == 0:
|
||||
print("")
|
||||
print("\n")
|
||||
|
||||
def _print_compute_op_mem_log(self, log, nodes, title=None):
|
||||
if title:
|
||||
print(title)
|
||||
for idx, (l, n) in enumerate(zip(log, nodes)):
|
||||
if n.op in ["placeholder", "get_attr", "output"]:
|
||||
continue
|
||||
if any(i in n.name for i in ["getitem", "getattr"]):
|
||||
continue
|
||||
print("%s:%.2f \t" % (n.name, l), end="")
|
||||
if (idx + 1) % 3 == 0:
|
||||
print("")
|
||||
print("\n")
|
||||
|
||||
def estimate_chunk_inference_mem(
|
||||
self,
|
||||
node_list: List,
|
||||
chunk_infos=None,
|
||||
print_mem=False,
|
||||
):
|
||||
"""
|
||||
Estimate inference memory with chunk
|
||||
|
||||
Args:
|
||||
node_list (List): list of graph nodes to estimate memory for
|
||||
chunk_infos (Dict): Chunk information. Defaults to None.
|
||||
print_mem (bool): Whether to print peak memory of every node. Defaults to False.
|
||||
|
||||
Returns:
|
||||
act_memory_peak_log (List): peak memory of every node
|
||||
act_memory_after_node_log (List): memory after executing every node
|
||||
active_node_list_log (List): active nodes of every node. active nodes refer to
|
||||
nodes generated but not deleted.
|
||||
"""
|
||||
act_memory = 0.0
|
||||
act_memory_peak_log = []
|
||||
act_memory_after_node_log = []
|
||||
active_node_list = []
|
||||
active_node_list_log = []
|
||||
not_contiguous_list = []
|
||||
user_to_last_uses = self._get_last_usr(node_list)
|
||||
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
|
||||
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
|
||||
|
||||
use_chunk = True if chunk_infos is not None else False
|
||||
chunk_within = False
|
||||
chunk_region_idx = None
|
||||
chunk_ratio = 1 # use it to estimate chunk mem
|
||||
chunk_inputs_names = []
|
||||
|
||||
if use_chunk:
|
||||
chunk_regions = [i["region"] for i in chunk_infos]
|
||||
chunk_starts = [i[0] for i in chunk_regions]
|
||||
chunk_ends = [i[1] for i in chunk_regions]
|
||||
chunk_inputs = [i["inputs"] for i in chunk_infos]
|
||||
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
|
||||
] + [j.name for i in chunk_inputs_non_chunk for j in i]
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
|
||||
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
|
||||
if use_chunk and idx in chunk_starts:
|
||||
chunk_within = True
|
||||
chunk_region_idx = chunk_starts.index(idx)
|
||||
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
|
||||
|
||||
# determine chunk ratio for current node
|
||||
if chunk_within:
|
||||
chunk_ratio = self._get_chunk_ratio(
|
||||
node,
|
||||
chunk_node_dim[chunk_region_idx],
|
||||
chunk_sizes[chunk_region_idx],
|
||||
)
|
||||
|
||||
# if node is placeholder, just add the size of the node
|
||||
if node.op == "placeholder":
|
||||
act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# skip output
|
||||
elif node.op == "output":
|
||||
continue
|
||||
# no change for non compute node
|
||||
elif is_non_memory_node(node):
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# node is a compute op
|
||||
# calculate tmp, output node and delete node memory
|
||||
else:
|
||||
# forward memory
|
||||
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
|
||||
act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
|
||||
act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
|
||||
# record max act memory
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# delete useless memory
|
||||
act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
|
||||
(1024**2))
|
||||
# delete unused vars not in chunk_input_list
|
||||
# we can't delete input nodes until chunk ends
|
||||
if chunk_within:
|
||||
act_memory -= self._get_chunk_delete_node_size(
|
||||
node,
|
||||
user_to_last_uses_no_free_var,
|
||||
chunk_ratio,
|
||||
chunk_inputs_names,
|
||||
) / (1024**2)
|
||||
else:
|
||||
act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
|
||||
chunk_inputs_names) / (1024**2)
|
||||
|
||||
# log active node, only effective without chunk
|
||||
self._add_active_node(node, active_node_list)
|
||||
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
|
||||
|
||||
# if node in chunk end nodes, restore chunk settings
|
||||
if use_chunk and idx in chunk_ends:
|
||||
act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
|
||||
act_memory -= self._get_chunk_inputs_size(
|
||||
chunk_inputs[chunk_region_idx],
|
||||
chunk_inputs_non_chunk[chunk_region_idx],
|
||||
node_list,
|
||||
chunk_regions[chunk_region_idx][1],
|
||||
) / (1024**2)
|
||||
chunk_within = False
|
||||
chunk_ratio = 1
|
||||
chunk_region_idx = None
|
||||
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
active_node_list_log.append(copy.deepcopy(active_node_list))
|
||||
|
||||
if print_mem:
|
||||
print("with chunk" if use_chunk else "without chunk")
|
||||
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
# self._print_compute_op_mem_log(
|
||||
# act_memory_after_node_log, node_list, "after"
|
||||
# )
|
||||
|
||||
# param_memory = parameter_size(gm)
|
||||
# all_memory = act_memory + param_memory
|
||||
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
|
||||
|
||||
def get_active_nodes(self, node_list: List) -> List:
|
||||
"""
|
||||
Get active nodes for every node
|
||||
|
||||
Args:
|
||||
node_list (List): list of graph nodes
|
||||
|
||||
Returns:
|
||||
active_node_list_log (List): active nodes of every node. active nodes refer to
|
||||
nodes generated but not deleted.
|
||||
"""
|
||||
active_node_list = []
|
||||
active_node_list_log = []
|
||||
user_to_last_uses = self._get_last_usr(node_list)
|
||||
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
|
||||
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
|
||||
for _, node in enumerate(node_list):
|
||||
# log active node, only effective without chunk
|
||||
self._add_active_node(node, active_node_list)
|
||||
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
|
||||
active_node_list_log.append(copy.deepcopy(active_node_list))
|
||||
return active_node_list_log
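# Editor's illustrative sketch (not part of the original patch): how the three
# logs returned by estimate_chunk_inference_mem line up per node. The values and
# names below are made up for illustration.
act_memory_peak_log = [12.0, 48.5, 30.2]            # MB peak while executing each node
act_memory_after_node_log = [12.0, 30.2, 18.0]      # MB still held after each node
active_node_list_log = [["x"], ["x", "y"], ["y"]]   # tensors generated but not yet deleted
peak_idx = act_memory_peak_log.index(max(act_memory_peak_log))
assert peak_idx == 1 and active_node_list_log[peak_idx] == ["x", "y"]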
@@ -0,0 +1,117 @@
from .trace_indice import TraceIndice
|
||||
from .utils import find_idx_by_name
|
||||
|
||||
|
||||
class ReorderGraph(object):
|
||||
"""
|
||||
Reorder node list and indice trace list
|
||||
"""
|
||||
|
||||
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||
self.trace_indice = trace_indice
|
||||
self.all_reorder_map = {
|
||||
i: i for i in range(len(self.trace_indice.indice_trace_list))
|
||||
}
|
||||
|
||||
def _get_reorder_map(self, chunk_info):
|
||||
reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
|
||||
|
||||
chunk_region_start = chunk_info["region"][0]
|
||||
chunk_region_end = chunk_info["region"][1]
|
||||
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
|
||||
chunk_prepose_nodes_idx = [
|
||||
find_idx_by_name(i.name, self.trace_indice.node_list)
|
||||
for i in chunk_prepose_nodes
|
||||
]
|
||||
# put prepose nodes ahead
|
||||
for idx, n in enumerate(chunk_prepose_nodes):
|
||||
n_idx = chunk_prepose_nodes_idx[idx]
|
||||
reorder_map[n_idx] = chunk_region_start + idx
|
||||
# put other nodes after prepose nodes
|
||||
for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
|
||||
if n in chunk_prepose_nodes:
|
||||
continue
|
||||
n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
|
||||
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
|
||||
reorder_map[n_idx] = n_idx + pos
|
||||
|
||||
return reorder_map
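# Editor's illustrative sketch (not part of the original patch): a reorder_map
# maps old node index -> new node index; applying it to any per-node list is
# exactly what _reorder_self_node_list does below.
def _apply_reorder_map(items, reorder_map):
    reordered = [None] * len(items)
    for old_idx, new_idx in reorder_map.items():
        reordered[new_idx] = items[old_idx]
    return reordered

# prepose node "c" (old index 2) is moved to the front of the region
assert _apply_reorder_map(["a", "b", "c"], {0: 1, 1: 2, 2: 0}) == ["c", "a", "b"]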
|
||||
|
||||
def _reorder_chunk_info(self, chunk_info, reorder_map):
|
||||
# update chunk info
|
||||
chunk_info["region"] = (
|
||||
chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
|
||||
chunk_info["region"][1],
|
||||
)
|
||||
new_inputs_dim = []
|
||||
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
|
||||
new_input_dim = {}
|
||||
for k, v in input_dim.items():
|
||||
new_input_dim[reorder_map[k]] = v
|
||||
new_inputs_dim.append(new_input_dim)
|
||||
chunk_info["inputs_dim"] = new_inputs_dim
|
||||
return chunk_info
|
||||
|
||||
def _update_all_reorder_map(self, reorder_map):
|
||||
for origin_idx, map_idx in self.all_reorder_map.items():
|
||||
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
|
||||
|
||||
def _reorder_self_node_list(self, reorder_map):
|
||||
new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
|
||||
self.trace_indice.node_list = new_node_list
|
||||
|
||||
def _reorder_idx_trace(self, reorder_map):
|
||||
# reorder list
|
||||
new_idx_trace_list = [
|
||||
None for _ in range(len(self.trace_indice.indice_trace_list))
|
||||
]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
|
||||
self.trace_indice.indice_trace_list = new_idx_trace_list
|
||||
# update compute
|
||||
for idx_trace in self.trace_indice.indice_trace_list:
|
||||
compute = idx_trace["compute"]
|
||||
for dim_compute in compute:
|
||||
for idx, i in enumerate(dim_compute):
|
||||
dim_compute[idx] = reorder_map[i]
|
||||
# update source
|
||||
for idx_trace in self.trace_indice.indice_trace_list:
|
||||
source = idx_trace["source"]
|
||||
for dim_idx, dim_source in enumerate(source):
|
||||
new_dim_source = {}
|
||||
for k, v in dim_source.items():
|
||||
new_dim_source[reorder_map[k]] = v
|
||||
source[dim_idx] = new_dim_source
|
||||
|
||||
def reorder_all(self, chunk_info):
|
||||
if chunk_info is None:
|
||||
return chunk_info
|
||||
if len(chunk_info["args"]["prepose_nodes"]) == 0:
|
||||
return chunk_info
|
||||
reorder_map = self._get_reorder_map(chunk_info)
|
||||
self._update_all_reorder_map(reorder_map)
|
||||
self._reorder_idx_trace(reorder_map)
|
||||
self._reorder_self_node_list(reorder_map)
|
||||
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
|
||||
return chunk_info
|
||||
|
||||
def reorder_node_list(self, node_list):
|
||||
new_node_list = [None for _ in range(len(node_list))]
|
||||
for old_idx, new_idx in self.all_reorder_map.items():
|
||||
new_node_list[new_idx] = node_list[old_idx]
|
||||
return new_node_list
|
||||
|
||||
def tmp_reorder(self, node_list, chunk_info):
|
||||
if len(chunk_info["args"]["prepose_nodes"]) == 0:
|
||||
return node_list, chunk_info
|
||||
reorder_map = self._get_reorder_map(chunk_info)
|
||||
|
||||
# new tmp node list
|
||||
new_node_list = [None for _ in range(len(node_list))]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_node_list[new_idx] = node_list[old_idx]
|
||||
|
||||
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
|
||||
return new_node_list, chunk_info
@@ -0,0 +1,319 @@
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .estimate_memory import EstimateMemory
|
||||
from .reorder_graph import ReorderGraph
|
||||
from .select_chunk import SelectChunk
|
||||
from .trace_flow import TraceFlow
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
|
||||
|
||||
|
||||
class SearchChunk(object):
|
||||
"""
|
||||
This is the core class for AutoChunk.
|
||||
|
||||
It defines the framework of the strategy of AutoChunk.
|
||||
Chunks will be selected one by one until the search stops.
|
||||
|
||||
The chunk search is as follows:
|
||||
1. find the peak memory node
|
||||
2. find the max chunk region according to the peak memory node
|
||||
3. find all possible chunk regions in the max chunk region
|
||||
4. find the best chunk region for current status
|
||||
5. goto 1
|
||||
|
||||
Attributes:
|
||||
gm: graph model
|
||||
print_mem (bool): print estimated memory
|
||||
trace_indice: trace the flow of every dim of every node to find all free dims
|
||||
trace_flow: determine the region chunk strategy
|
||||
reorder_graph: reorder nodes to improve chunk efficiency
|
||||
estimate_memory: estimate memory with chunk
|
||||
select_chunk: select the best chunk region
|
||||
|
||||
Args:
|
||||
gm: graph model
|
||||
max_memory (int): max memory in MB
|
||||
print_mem (bool): print estimated memory
|
||||
"""
|
||||
|
||||
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
|
||||
self.print_mem = print_mem
|
||||
self.print_progress = print_progress
|
||||
self.trace_indice = TraceIndice(list(gm.graph.nodes))
|
||||
self.estimate_memory = EstimateMemory()
|
||||
self._init_trace()
|
||||
self.trace_flow = TraceFlow(self.trace_indice)
|
||||
self.reorder_graph = ReorderGraph(self.trace_indice)
|
||||
self.select_chunk = SelectChunk(
|
||||
self.trace_indice,
|
||||
self.estimate_memory,
|
||||
self.reorder_graph,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
|
||||
def _init_trace(self) -> None:
|
||||
"""
|
||||
find the max trace range for every node
|
||||
reduce the computation complexity of trace_indice
|
||||
"""
|
||||
# find all max ranges
|
||||
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
|
||||
cur_node_idx = len(self._get_free_var_idx())
|
||||
max_chunk_region_list = []
|
||||
while True:
|
||||
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
|
||||
cur_node_idx = max_chunk_region[1]
|
||||
if cur_node_idx == len(active_nodes) - 1:
|
||||
break
|
||||
max_chunk_region_list.append(max_chunk_region)
|
||||
|
||||
# there is nothing to limit for the first range, so drop it and let the next range start from 0
|
||||
max_chunk_region_list = max_chunk_region_list[1:]
|
||||
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
|
||||
|
||||
# set trace range and do the trace
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk start tracing indice")
|
||||
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
|
||||
self.trace_indice.trace_indice()
|
||||
|
||||
def _find_peak_node(self, mem_peak: List) -> int:
|
||||
max_value = max(mem_peak)
|
||||
max_idx = mem_peak.index(max_value)
|
||||
return max_idx
|
||||
|
||||
def _get_free_var_idx(self) -> List:
|
||||
"""
|
||||
Get free var index
|
||||
|
||||
Returns:
|
||||
free_var_idx (List): indices of all free vars
|
||||
"""
|
||||
free_var_idx = []
|
||||
for idx, n in enumerate(self.trace_indice.node_list):
|
||||
if n.op == "placeholder" and get_node_shape(n) is not None:
|
||||
free_var_idx.append(idx)
|
||||
return free_var_idx
|
||||
|
||||
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
|
||||
"""
|
||||
Search max chunk region according to peak memory node
|
||||
|
||||
Chunk region starts extending from the peak node, stops where free var num is min
|
||||
|
||||
Args:
|
||||
active_node (List): active node status for every node
|
||||
peak_node_idx (int): peak memory node idx
|
||||
chunk_regions (List): chunk region infos
|
||||
|
||||
Returns:
|
||||
chunk_region_start (int)
|
||||
chunk_region_end (int)
|
||||
"""
|
||||
free_vars = self._get_free_var_idx()
|
||||
free_var_num = len(free_vars)
|
||||
active_node_num = [len(i) for i in active_node]
|
||||
min_active_node_num = min(active_node_num[free_var_num:])
|
||||
threshold = max(free_var_num, min_active_node_num)
|
||||
|
||||
# from peak_node to free_var
|
||||
inside_flag = False
|
||||
chunk_region_start = free_var_num
|
||||
for i in range(peak_node_idx, -1, -1):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_start = i + 1
|
||||
break
|
||||
|
||||
# from peak_node to the end of the node list
|
||||
inside_flag = False
|
||||
chunk_region_end = len(active_node) - 1
|
||||
for i in range(peak_node_idx, len(active_node)):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_end = i
|
||||
break
|
||||
|
||||
# avoid chunk regions overlap
|
||||
if chunk_regions is not None:
|
||||
for i in chunk_regions:
|
||||
region = i["region"]
|
||||
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
|
||||
return None
|
||||
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
|
||||
chunk_region_start = region[1] + 1
|
||||
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
|
||||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
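# Editor's illustrative sketch (not part of the original patch): the threshold
# scan above on toy active-node counts. The region grows outwards from the peak
# and stops where the count rises back above the threshold on each side.
def _toy_max_region(counts, peak, threshold):
    start, inside = 0, False
    for i in range(peak, -1, -1):
        if counts[i] <= threshold:
            inside = True
        if inside and counts[i] > threshold:
            start = i + 1
            break
    end, inside = len(counts) - 1, False
    for i in range(peak, len(counts)):
        if counts[i] <= threshold:
            inside = True
        if inside and counts[i] > threshold:
            end = i
            break
    return start, end

assert _toy_max_region([5, 2, 2, 6, 7, 6, 2, 2, 5], peak=4, threshold=3) == (1, 8)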
|
||||
|
||||
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||
"""
|
||||
Find chunk info for a region.
|
||||
|
||||
We are given the region start and region end, and need to find out all chunk info for it.
|
||||
We first loop every dim of start node and end node, to see if we can find dim pair,
|
||||
which is linked in a flow and not computed.
|
||||
If found, we then search flow in the whole region to find out all chunk infos.
|
||||
|
||||
Args:
|
||||
input_trace (List): node's input trace in region
|
||||
output_trace (List): node's output trace in region
|
||||
start_idx (int): region start node index
|
||||
end_idx (int): region end node index
|
||||
|
||||
Returns:
|
||||
chunk_infos: possible regions found
|
||||
"""
|
||||
start_traces = input_trace[start_idx]
|
||||
end_trace = output_trace[end_idx]
|
||||
end_node = self.trace_indice.node_list[end_idx]
|
||||
chunk_infos = []
|
||||
for end_dim, _ in enumerate(end_trace["indice"]):
|
||||
if len(start_traces) > 1:
|
||||
continue
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, _ in enumerate(start_trace["indice"]):
|
||||
# dim size cannot be 1
|
||||
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
|
||||
continue
|
||||
# must have users
|
||||
if len(end_node.users) == 0:
|
||||
continue
|
||||
# check index source align
|
||||
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
|
||||
continue
|
||||
# check index compute
|
||||
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
|
||||
continue
|
||||
# flow search
|
||||
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
|
||||
if chunk_info is None:
|
||||
continue
|
||||
# check index duplicate
|
||||
if not self.trace_flow.check_index_duplicate(chunk_info):
|
||||
continue
|
||||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
|
||||
"""
|
||||
Search every possible region within the max chunk region.
|
||||
|
||||
Args:
|
||||
max_chunk_region (Tuple)
|
||||
peak_node (Node): peak memory node
|
||||
|
||||
Returns:
|
||||
possible_chunk_region (List)
|
||||
"""
|
||||
possible_chunk_region = []
|
||||
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
|
||||
input_trace = [] # trace of a node's input nodes
|
||||
for _, n in enumerate(self.trace_indice.node_list):
|
||||
cur_trace = {}
|
||||
for arg in n.args:
|
||||
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
|
||||
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
|
||||
input_trace.append(cur_trace)
|
||||
|
||||
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||
# skip non compute nodes
|
||||
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
|
||||
self.trace_indice.node_list[end_idx]):
|
||||
continue
|
||||
|
||||
# select free dim
|
||||
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
|
||||
if len(chunk_info) > 0:
|
||||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
|
||||
def _step_search(
|
||||
self,
|
||||
mem_peak: List[float],
|
||||
active_node: List[List[Node]],
|
||||
chunk_infos: List[Dict],
|
||||
) -> Dict:
|
||||
"""
|
||||
Find one chunk region
|
||||
|
||||
The chunk search is as follows:
|
||||
1. find the peak memory node
|
||||
2. find the max chunk region according to the peak memory node
|
||||
3. find all possible chunk regions in the max chunk region
|
||||
4. find the best chunk region for current status
|
||||
|
||||
Args:
|
||||
mem_peak (List): peak memory for every node
|
||||
active_node (List[List[Node]]): active node for every node
|
||||
chunk_infos (List[Dict]): all chunk info
|
||||
|
||||
Returns:
|
||||
best_chunk_region (Dict)
|
||||
"""
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
|
||||
if max_chunk_region is None:
|
||||
return None
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
|
||||
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
|
||||
max_chunk_region, mem_peak)
|
||||
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||
return best_chunk_region
|
||||
|
||||
def _stop_search(self, init_mem_peak, mem_peak):
|
||||
sorted_init_mem_peak = sorted(init_mem_peak)
|
||||
if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
|
||||
return True
|
||||
return False
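# Editor's illustrative sketch (not part of the original patch): the search stops
# once the current peak drops below the median of the initial per-node peaks.
init_mem_peak = [10.0, 40.0, 25.0, 80.0]                            # hypothetical MB values
median_peak = sorted(init_mem_peak)[int(len(init_mem_peak) * 0.5)]  # 40.0
assert max([22.0, 30.0]) < median_peak                              # search would stop here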
|
||||
|
||||
def search_region(self) -> Dict:
|
||||
"""
|
||||
Search all chunk regions:
|
||||
1. Estimate current memory
|
||||
2. Find best chunk for current memory
|
||||
3. goto 1
|
||||
|
||||
Returns:
|
||||
chunk_infos (Dict)
|
||||
"""
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk start searching chunk regions")
|
||||
|
||||
chunk_infos = []
|
||||
(
|
||||
init_mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
|
||||
mem_peak = init_mem_peak
|
||||
|
||||
while True:
|
||||
chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
|
||||
if chunk_info is None:
|
||||
break
|
||||
chunk_infos.append(chunk_info)
|
||||
|
||||
(
|
||||
mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
|
||||
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
|
||||
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
|
||||
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
if self.print_mem:
|
||||
self.print_mem = False
|
||||
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
|
||||
return chunk_infos
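# Editor's illustrative sketch (not part of the original patch): a minimal way to
# drive the search, assuming `gm` is a torch.fx GraphModule whose nodes already
# carry meta["tensor_meta"] (e.g. produced by a meta tracer / shape propagation):
#
#     search = SearchChunk(gm, max_memory=2000, print_mem=False, print_progress=True)
#     chunk_infos = search.search_region()
#     # each entry describes one region, e.g. chunk_infos[0]["region"],
#     # chunk_infos[0]["inputs"] and chunk_infos[0]["chunk_size"]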
@@ -0,0 +1,224 @@
from .estimate_memory import EstimateMemory
|
||||
from .reorder_graph import ReorderGraph
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import is_non_compute_node
|
||||
|
||||
|
||||
class SelectChunk(object):
|
||||
def __init__(
|
||||
self,
|
||||
trace_indice: TraceIndice,
|
||||
estimate_memory: EstimateMemory,
|
||||
reorder_graph: ReorderGraph,
|
||||
max_memory=None,
|
||||
):
|
||||
self.trace_indice = trace_indice
|
||||
self.estimate_memory = estimate_memory
|
||||
self.reorder_graph = reorder_graph
|
||||
if max_memory is not None:
|
||||
self.stratge = "fit_memory"
|
||||
self.max_memory = max_memory # MB
|
||||
else:
|
||||
self.stratge = "min_memory"
|
||||
|
||||
def _select_best_chunk_region(
|
||||
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
):
|
||||
if self.stratge == "min_memory":
|
||||
best_region = self._select_min_memory_chunk_region(
|
||||
possible_chunk_regions,
|
||||
chunk_infos,
|
||||
peak_node,
|
||||
max_chunk_region,
|
||||
mem_peak,
|
||||
)
|
||||
elif self.stratge == "fit_memory":
|
||||
best_region = self._select_fit_memory_chunk_region(
|
||||
possible_chunk_regions,
|
||||
chunk_infos,
|
||||
peak_node,
|
||||
max_chunk_region,
|
||||
mem_peak,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError()
|
||||
return best_region
|
||||
|
||||
def _select_fit_memory_chunk_region(
|
||||
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
):
|
||||
# stop chunking if the max memory already satisfies the memory limit
|
||||
if max(mem_peak) < self.max_memory:
|
||||
return None
|
||||
|
||||
# remove illegal regions
|
||||
illegal_regions = []
|
||||
for i in possible_chunk_regions:
|
||||
if not self._is_legal_region(i, chunk_infos):
|
||||
illegal_regions.append(i)
|
||||
for i in illegal_regions:
|
||||
if i in possible_chunk_regions:
|
||||
possible_chunk_regions.remove(i)
|
||||
|
||||
if len(possible_chunk_regions) == 0:
|
||||
return None
|
||||
|
||||
# get mem for chunk region
|
||||
regions_dict = []
|
||||
for region in possible_chunk_regions:
|
||||
cur_region = region.copy()
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||
self.trace_indice.node_list, cur_region
|
||||
)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[
|
||||
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||
]
|
||||
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||
if cur_chunk_region_max_peak < self.max_memory:
|
||||
regions_dict.append(
|
||||
{
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(
|
||||
region["region"][0], region["region"][1]
|
||||
),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
}
|
||||
)
|
||||
# no region found
|
||||
if len(regions_dict) == 0:
|
||||
raise RuntimeError("Search failed. Try a larger memory threshold.")
|
||||
|
||||
# select the min chunk len
|
||||
chunk_len = [i["chunk_len"] for i in regions_dict]
|
||||
best_region_idx = chunk_len.index(min(chunk_len))
|
||||
best_region = regions_dict[best_region_idx]
|
||||
|
||||
# get max chunk size
|
||||
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
|
||||
return best_region
|
||||
|
||||
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
|
||||
chunk_size = 1
|
||||
reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_max_mem = 0
|
||||
# search a region
|
||||
while cur_chunk_max_mem < self.max_memory:
|
||||
chunk_size *= 2
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[
|
||||
reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
|
||||
+ 1
|
||||
]
|
||||
)
|
||||
# search exact size
|
||||
chunk_info = chunk_region_dict["chunk_info"]
|
||||
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
||||
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
|
||||
)
|
||||
return chunk_info
|
||||
|
||||
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
|
||||
if left >= 16:
|
||||
gap = 4
|
||||
else:
|
||||
gap = 1
|
||||
chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||
while right >= left + gap:
|
||||
mid = int((left + right) / 2 + 0.5)
|
||||
chunk_info["chunk_size"] = mid
|
||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||
)
|
||||
if cur_chunk_max_mem >= self.max_memory:
|
||||
right = mid - gap
|
||||
else:
|
||||
left = mid + gap
|
||||
return left
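# Editor's illustrative sketch (not part of the original patch): the same
# gap-based binary search on a toy cost model where memory grows linearly with
# chunk_size and must stay below the budget.
def _toy_chunk_size_search(left, right, budget, mem_per_chunk):
    gap = 4 if left >= 16 else 1
    while right >= left + gap:
        mid = int((left + right) / 2 + 0.5)
        if mid * mem_per_chunk >= budget:   # stands in for the real memory estimate
            right = mid - gap
        else:
            left = mid + gap
    return left

assert _toy_chunk_size_search(8, 16, budget=10, mem_per_chunk=1) == 10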
|
||||
|
||||
def _get_compute_node_num(self, start, end):
|
||||
count = 0
|
||||
for i in self.trace_indice.node_list[start : end + 1]:
|
||||
if not is_non_compute_node(i):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _select_min_memory_chunk_region(
|
||||
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
):
|
||||
# remove illegal regions
|
||||
illegal_regions = []
|
||||
for i in possible_chunk_regions:
|
||||
if not self._is_legal_region(i, chunk_infos):
|
||||
illegal_regions.append(i)
|
||||
for i in illegal_regions:
|
||||
if i in possible_chunk_regions:
|
||||
possible_chunk_regions.remove(i)
|
||||
|
||||
if len(possible_chunk_regions) == 0:
|
||||
return None
|
||||
|
||||
# get mem for chunk region
|
||||
regions_dict = []
|
||||
for region in possible_chunk_regions:
|
||||
cur_region = region.copy()
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||
self.trace_indice.node_list, cur_region
|
||||
)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[
|
||||
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||
]
|
||||
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||
regions_dict.append(
|
||||
{
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(
|
||||
region["region"][0], region["region"][1]
|
||||
),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
}
|
||||
)
|
||||
|
||||
# select the min mem
|
||||
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
|
||||
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
|
||||
best_region = regions_dict[best_region_idx]["chunk_info"]
|
||||
if best_region is not None:
|
||||
best_region["chunk_size"] = 1
|
||||
return best_region
|
||||
|
||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||
if cur_chunk_info in chunk_infos:
|
||||
return False
|
||||
if chunk_region_end < chunk_region_start:
|
||||
return False
|
||||
for i in chunk_infos:
|
||||
region = i["region"]
|
||||
if not (
|
||||
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||
):
|
||||
return False
|
||||
return True
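# Editor's illustrative sketch (not part of the original patch): the overlap test
# above only accepts a candidate that lies strictly on one side of every
# already-selected region.
def _toy_overlaps(candidate, existing):
    start, end = candidate
    return not ((start > existing[1] and end > existing[1])
                or (start < existing[0] and end < existing[0]))

assert _toy_overlaps((3, 6), (5, 9)) is True     # partial overlap -> illegal
assert _toy_overlaps((0, 4), (5, 9)) is False    # disjoint -> allowed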
@@ -0,0 +1,445 @@
from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import (
|
||||
find_chunk_all_input_nodes,
|
||||
find_chunk_compute_input_and_output_nodes,
|
||||
find_idx_by_name,
|
||||
flat_list,
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
is_non_compute_node_except_placeholder,
|
||||
)
|
||||
|
||||
|
||||
class TraceFlow(object):
|
||||
|
||||
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||
self.trace_indice = trace_indice
|
||||
|
||||
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
|
||||
"""
|
||||
Check two given indices: one index should be the source of the other
|
||||
Args:
|
||||
start_dim(int): start node chunk dim
start_node(node): start node
start_idx(int): start node index in the node list
end_dim(int): end node chunk dim
end_node(node): end node
|
||||
|
||||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
|
||||
for node_idx, node_dim in sorted_source:
|
||||
if node_idx == start_node_idx and start_dim in node_dim:
|
||||
return True
|
||||
# it means we meet a node outside the loop, and the node is not an input node
|
||||
if node_idx < start_idx:
|
||||
return False
|
||||
return False
|
||||
|
||||
def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
|
||||
"""
|
||||
Check two given indices: make sure they haven't been computed in the source trace.
|
||||
Args:
|
||||
start_idx(int): chunk region start index
end_dim(int): end node chunk dim
end_node(node): end node
end_idx(int): chunk region end index
|
||||
|
||||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_compute = end_node_trace["compute"][end_dim]
|
||||
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
||||
return False
|
||||
return True
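# Editor's illustrative sketch (not part of the original patch): the dim is
# rejected as soon as any node inside the candidate region has already computed
# (e.g. reduced over) it.
end_node_compute = [12, 30]        # hypothetical node indices that computed this dim
start_idx, end_idx = 10, 25
assert any(start_idx <= i <= end_idx for i in end_node_compute)   # 12 is inside -> check fails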
|
||||
|
||||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
|
||||
dim_source = node_from_source[node_from_dim]
|
||||
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
|
||||
for k, v in dim_source.items():
|
||||
if k == node_to_idx:
|
||||
return v
|
||||
return None
|
||||
|
||||
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
if (input_node_idx in node_trace_source[node_dim]
|
||||
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
|
||||
return node_dim
|
||||
return None
|
||||
|
||||
def check_index_duplicate(self, chunk_infos, return_dim=False):
|
||||
input_dim_after_node = {}
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
|
||||
if inherit_dim:
|
||||
input_dim_after_node[k] = inherit_dim
|
||||
|
||||
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
|
||||
if is_non_compute_node_except_placeholder(node):
|
||||
continue
|
||||
count = 0
|
||||
duplicate_dims = []
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
duplicate_dim = []
|
||||
duplicate_flag = False
|
||||
dim_source = node_trace_source[node_dim]
|
||||
for k, v in dim_source.items():
|
||||
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
|
||||
if k in input_dim_after_node and input_dim_after_node[k] in v:
|
||||
duplicate_flag = True
|
||||
duplicate_dim.append((k, v))
|
||||
duplicate_dims.append(duplicate_dim)
|
||||
if duplicate_flag:
|
||||
count += 1
|
||||
|
||||
if count > 1:
|
||||
if return_dim:
|
||||
return False, duplicate_dims
|
||||
else:
|
||||
return False
|
||||
if return_dim:
|
||||
return True, None
|
||||
else:
|
||||
return True
|
||||
|
||||
def _assgin_single_node_flow(
|
||||
self,
|
||||
arg_node: Node,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
cur_node_dim: int,
|
||||
cur_node_compute: Dict,
|
||||
cur_node_source: Dict,
|
||||
cur_node_fix_dim: List,
|
||||
all_node_info: Dict,
|
||||
next_node_list: List,
|
||||
) -> bool:
|
||||
"""
|
||||
Given the current node and one of its arg node,
|
||||
this function finds out arg node's chunk dim and fix dim
|
||||
|
||||
Args:
|
||||
arg_node (Node): input node
|
||||
start_idx (int): chunk region start
|
||||
end_idx (int): chunk region end
|
||||
cur_node_dim (int): current node chunk dim
|
||||
cur_node_compute (Dict): current node compute dict
|
||||
cur_node_source (Dict): current node source dict
|
||||
cur_node_fix_dim (List): current node fix dim
|
||||
all_node_info (Dict): all node chunk info in the chunk region
|
||||
next_node_list (List)
|
||||
|
||||
Returns:
|
||||
bool: True if this node can be added to the flow, False otherwise.
|
||||
"""
|
||||
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
|
||||
# arg in chunk range or be inputs
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
return True
|
||||
|
||||
# find arg dim
|
||||
if cur_node_dim is not None:
|
||||
# dim is computed
|
||||
if arg_idx in cur_node_compute[cur_node_dim]:
|
||||
return False
|
||||
if arg_idx not in cur_node_source[cur_node_dim]:
|
||||
arg_dim = None
|
||||
else:
|
||||
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
||||
# chunk dim should be None if shape size is 1
|
||||
if get_node_shape(arg_node)[arg_dim] == 1:
|
||||
arg_dim = None
|
||||
else:
|
||||
arg_dim = None
|
||||
|
||||
# get fix dim
|
||||
arg_fix_dim = []
|
||||
if cur_node_dim is not None:
|
||||
for i in cur_node_fix_dim:
|
||||
fix_dim_source = cur_node_source[i]
|
||||
if arg_idx in fix_dim_source:
|
||||
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
||||
|
||||
# if already in node_info, arg dim must be same
|
||||
if arg_node in all_node_info:
|
||||
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
|
||||
return False
|
||||
all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
|
||||
# else add it to list
|
||||
else:
|
||||
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
|
||||
|
||||
next_node_list.append(arg_node)
|
||||
return True
|
||||
|
||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
while len(cur_node_list) > 0:
|
||||
next_node_list = []
|
||||
|
||||
for cur_node in cur_node_list:
|
||||
# get cur node info
|
||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
if cur_node_chunk_dim is not None:
|
||||
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
|
||||
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
|
||||
else:
|
||||
cur_node_compute = cur_node_source = None
|
||||
|
||||
# get all valid args
|
||||
arg_list = []
|
||||
for arg in cur_node.all_input_nodes:
|
||||
if type(arg) != type(cur_node):
|
||||
continue
|
||||
if is_non_compute_node(arg):
|
||||
continue
|
||||
arg_list.append(arg)
|
||||
flow_flag = self._assgin_single_node_flow(
|
||||
arg,
|
||||
start_idx,
|
||||
end_idx,
|
||||
cur_node_chunk_dim,
|
||||
cur_node_compute,
|
||||
cur_node_source,
|
||||
cur_node_fix_dim,
|
||||
all_node_info,
|
||||
next_node_list,
|
||||
)
|
||||
if flow_flag == False:
|
||||
return None
|
||||
|
||||
if len(arg_list) == 2:
|
||||
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
|
||||
for arg in arg_list:
|
||||
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||
arg_shape = get_node_shape(arg)
|
||||
# add all dim as fix dim except chunk dim
|
||||
for i, shape in enumerate(arg_shape):
|
||||
if shape != 1 and i != cur_node_chunk_dim:
|
||||
if i == arg_chunk_dim:
|
||||
return None
|
||||
if i not in arg_fix_dim:
|
||||
arg_fix_dim.append(i)
|
||||
elif "einsum" in cur_node.name:
|
||||
pass
|
||||
elif "matmul" in cur_node.name:
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
cur_node_list = next_node_list
|
||||
return all_node_info
|
||||
|
||||
def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
|
||||
"""
|
||||
Get the chunk dim of every entry for every input node, and remove unchunked input nodes
|
||||
|
||||
Args:
|
||||
inputs (List[Node]): input nodes
|
||||
all_node_info (Dict): describe all node's chunk dim and fix dim
|
||||
start_idx (int): chunk start idx
|
||||
end_idx (int): chunk end idx
|
||||
|
||||
Returns:
|
||||
inputs (List(Node)): new inputs
|
||||
inputs_dim (List): chunk dim for inputs
|
||||
"""
|
||||
inputs_dim = []
|
||||
remove_inputs = []
|
||||
for input_node in inputs:
|
||||
input_dict = {}
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
for user in input_node.users.keys():
|
||||
# skip non compute
|
||||
if is_non_compute_node(user):
|
||||
continue
|
||||
# untraced node, mostly non compute
|
||||
if user not in all_node_info:
|
||||
continue
|
||||
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
|
||||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||
if chunk_dim is not None:
|
||||
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
|
||||
if input_node_idx in user_source:
|
||||
if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
|
||||
input_dict[user_idx] = [None]
|
||||
else:
|
||||
input_dict[user_idx] = user_source[input_node_idx]
|
||||
else:
|
||||
return None, None
|
||||
if len(input_dict) == 0:
|
||||
remove_inputs.append(input_node)
|
||||
else:
|
||||
inputs_dim.append(input_dict)
|
||||
# remove unchunked inputs
|
||||
for i in remove_inputs:
|
||||
if i in inputs:
|
||||
inputs.remove(i)
|
||||
return inputs, inputs_dim
|
||||
|
||||
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
|
||||
"""
|
||||
Find nodes in the chunk region that do not depend on the chunk dim and prepose them (move them before the chunk loop)
|
||||
|
||||
Args:
|
||||
all_node_info (Dict): describe all node's chunk dim and fix dim
|
||||
start_idx (int): chunk start idx
|
||||
end_idx (int): chunk end idx
|
||||
|
||||
Returns:
|
||||
List[Node]: all nodes to be preposed
|
||||
"""
|
||||
# get all possible prepose nodes
|
||||
maybe_prepose_nodes = []
|
||||
for node, node_info in all_node_info.items():
|
||||
if node_info["chunk_dim"] is None:
|
||||
maybe_prepose_nodes.append(node)
|
||||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
|
||||
while len(maybe_prepose_nodes) > 0:
|
||||
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
|
||||
tmp_cur_related_prepose_nodes = []
|
||||
prepose_flag = True
|
||||
|
||||
# loop cur node's all arg until out of chunk
|
||||
while len(tmp_cur_prepose_nodes) > 0:
|
||||
if prepose_flag == False:
|
||||
break
|
||||
tmp_next_prepose_nodes = []
|
||||
tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
|
||||
for cur_prepose_node in tmp_cur_prepose_nodes:
|
||||
if prepose_flag == False:
|
||||
break
|
||||
for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
|
||||
if type(cur_prepose_node_arg) != type(cur_prepose_node):
|
||||
continue
|
||||
# out of loop
|
||||
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
|
||||
end_idx):
|
||||
continue
|
||||
# compute op in loop
|
||||
elif cur_prepose_node_arg in all_node_info:
|
||||
if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
|
||||
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||
else:
|
||||
prepose_flag = False
|
||||
break
|
||||
# non compute op
|
||||
else:
|
||||
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
|
||||
|
||||
if prepose_flag == False:
|
||||
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
|
||||
continue
|
||||
else:
|
||||
for n in tmp_cur_related_prepose_nodes:
|
||||
if n not in prepose_nodes:
|
||||
prepose_nodes.append(n)
|
||||
if n in maybe_prepose_nodes:
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
|
||||
|
||||
return prepose_nodes
|
||||
|
||||
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||
# we need to log input nodes to avoid deleting them in the loop
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
|
||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||
for n in chunk_info["args"]["prepose_nodes"]:
|
||||
chunk_node_list.remove(n)
|
||||
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
|
||||
for i in non_chunk_inputs:
|
||||
if i not in chunk_info["inputs"]:
|
||||
chunk_info["inputs_non_chunk"].append(i)
|
||||
return chunk_info
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
|
||||
# only a single output is supported
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
# get every node's chunk dim and fix dim
|
||||
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||
if all_node_info is None:
|
||||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
chunk_info = {
|
||||
"region": (start_idx, end_idx),
|
||||
"inputs": inputs,
|
||||
"inputs_non_chunk": [],
|
||||
"inputs_dim": inputs_dim,
|
||||
"outputs": outputs,
|
||||
"outputs_dim": end_dim,
|
||||
"node_chunk_dim": all_node_info,
|
||||
"args": {},
|
||||
}
|
||||
|
||||
# move useless nodes ahead of loop
|
||||
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
|
||||
|
||||
# find non chunk inputs
|
||||
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
|
||||
|
||||
# reassgin reshape size, some size may have changed due to chunk
|
||||
chunk_info = self._reassgin_reshape_size(chunk_info)
|
||||
|
||||
return chunk_info
|
||||
|
||||
def _reassgin_reshape_size(self, chunk_info):
|
||||
"""
|
||||
Some shape args in reshape may have changed due to chunk;
reassign those changed shapes.
|
||||
"""
|
||||
chunk_region = chunk_info["region"]
|
||||
reshape_size = {}
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
|
||||
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
|
||||
if any(i in node.name for i in ["reshape", "view"]):
|
||||
reshape_args = flat_list(node.args[1:])
|
||||
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
||||
new_shape = ""
|
||||
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
|
||||
if reshape_arg_dim == chunk_dim:
|
||||
new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
|
||||
else:
|
||||
if isinstance(reshape_arg, int):
|
||||
new_shape += "%s, " % str(reshape_arg)
|
||||
else:
|
||||
new_shape += "%s, " % reshape_arg.name
|
||||
new_shape = new_shape[:-2]
|
||||
origin_shape = str(reshape_args)[1:-1]
|
||||
reshape_size[node.name] = [origin_shape, new_shape]
|
||||
chunk_info["reshape_size"] = reshape_size
|
||||
return chunk_info
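# Editor's illustrative sketch (not part of the original patch): the string that
# _reassgin_reshape_size records for a hypothetical view(2, 128, 64) whose dim 1
# is the chunk dim with size 128.
reshape_args, chunk_dim, chunk_shape = [2, 128, 64], 1, 128
new_shape = ", ".join(
    "min(chunk_size, %d - chunk_idx)" % chunk_shape if dim == chunk_dim else str(arg)
    for dim, arg in enumerate(reshape_args))
assert new_shape == "2, min(chunk_size, 128 - chunk_idx), 64"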
@@ -0,0 +1,703 @@
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
|
||||
|
||||
|
||||
class TraceIndice(object):
|
||||
"""
|
||||
Trace all indice information for every node.
|
||||
|
||||
Indice is a logical concept. Equal dims can be treated as one indice.
|
||||
eg. dim(x1) = [a, b, c]
|
||||
dim(x2) = [d, e, f]
|
||||
and we have x3 = x1 * x2.
|
||||
then a=d, b=e, c=f, due to the broadcast property,
|
||||
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
|
||||
This class will record every node's dims' indice, compute and source.
|
||||
|
||||
Attributes:
|
||||
node_list (List)
|
||||
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
|
||||
indice_view_list (Dict): not used for now
|
||||
indice_count (int): record indice number
|
||||
|
||||
Args:
|
||||
node_list (List)
|
||||
"""
|
||||
|
||||
def __init__(self, node_list: List[Node]) -> None:
|
||||
self.node_list = node_list
|
||||
self.indice_trace_list = self._init_indice_trace_list()
|
||||
self.indice_view_list = {}
|
||||
self.indice_count = -1
|
||||
self.trace_range = []
|
||||
self.active_node_list = []
|
||||
|
||||
def _init_indice_trace_list(self):
|
||||
indice_trace_list = []
|
||||
for n in self.node_list:
|
||||
if get_node_shape(n) != None:
|
||||
cur_trace = {
|
||||
"indice": [None for _ in range(len(get_node_shape(n)))],
|
||||
"compute": [[] for _ in range(len(get_node_shape(n)))],
|
||||
"source": [{} for _ in range(len(get_node_shape(n)))],
|
||||
}
|
||||
else:
|
||||
cur_trace = {"indice": [], "compute": [], "source": []}
|
||||
indice_trace_list.append(cur_trace)
|
||||
return indice_trace_list
|
||||
|
||||
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
|
||||
self.trace_range = trace_range
|
||||
self.active_node_list = active_node_list
|
||||
|
||||
def _add_indice(self):
|
||||
"""
|
||||
Update the count and return it. To record the idx number.
|
||||
|
||||
Returns:
|
||||
indice_count: int
|
||||
"""
|
||||
self.indice_count += 1
|
||||
return self.indice_count
|
||||
|
||||
def _del_dim(self, idx, dim_idx):
|
||||
self.indice_trace_list[idx]["indice"].pop(dim_idx)
|
||||
self.indice_trace_list[idx]["compute"].pop(dim_idx)
|
||||
self.indice_trace_list[idx]["source"].pop(dim_idx)
|
||||
|
||||
def _add_dim(self, node_idx, dim_idx):
|
||||
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
|
||||
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
|
||||
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
|
||||
|
||||
def _transform_indice(self, node, node_dim):
|
||||
node_idx = self._find_indice_trace_from_node(node)
|
||||
dims = list(range(len(node_idx)))
|
||||
return dims[node_dim]
|
||||
|
||||
def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
|
||||
node_from_dim = self._transform_indice(node_from, node_from_dim)
|
||||
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
||||
node_from_trace = self._find_trace_from_node(node_from)
|
||||
node_to_trace = self._find_trace_from_node(node_to)
|
||||
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
|
||||
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
|
||||
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
|
||||
|
||||
def _inherit_all_computation(self, node_from, node_to):
|
||||
node_from_compute = self._find_compute_trace_from_node(node_from)
|
||||
node_to_compute = self._find_compute_trace_from_node(node_to)
|
||||
assert len(node_from_compute) == len(node_to_compute)
|
||||
for i in range(len(node_from_compute)):
|
||||
self._add_source(node_from, i, node_to, i)
|
||||
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
|
||||
|
||||
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
|
||||
node_from_dim = self._transform_indice(node_from, node_from_dim)
|
||||
node_from_trace_source = self._find_source_trace_from_node(node_from)
|
||||
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
||||
node_to_trace_source = self._find_source_trace_from_node(node_to)
|
||||
node_from_idx = find_idx_by_name(node_from.name, self.node_list)
|
||||
if init:
|
||||
node_to_trace_source[node_to_dim] = {}
|
||||
# add dim to cur new source
|
||||
if node_from_idx not in node_to_trace_source[node_to_dim]:
|
||||
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
|
||||
else:
|
||||
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
|
||||
node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim)
|
||||
# update inputs source
|
||||
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
|
||||
if node_idx not in node_to_trace_source[node_to_dim]:
|
||||
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
|
||||
else:
|
||||
for d in node_dim:
|
||||
if d not in node_to_trace_source[node_to_dim][node_idx]:
|
||||
node_to_trace_source[node_to_dim][node_idx].append(d)
|
||||
|
||||
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
|
||||
if exclude == None:
|
||||
exclude = []
|
||||
else:
|
||||
exclude = [self._transform_indice(node_to, i) for i in exclude]
|
||||
node_from_compute = self._find_compute_trace_from_node(node_from)
|
||||
node_to_compute = self._find_compute_trace_from_node(node_to)
|
||||
# assert len(node_from_compute) == len(node_to_compute)
|
||||
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
|
||||
if self._transform_indice(node_to, i) in exclude:
|
||||
continue
|
||||
self._add_source(node_from, i, node_to, i)
|
||||
for j in node_from_compute[i]:
|
||||
if j not in node_to_compute[i]:
|
||||
node_to_compute[i].append(j)
|
||||
|
||||
def _mark_computation(self, node, idx, dim):
|
||||
"""
|
||||
Mark some dims of node as computed.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
idx (int): node index
|
||||
dim (list or int): dims to be marked as computed
|
||||
"""
|
||||
if isinstance(dim, int):
|
||||
dim = [dim]
|
||||
dims = list(range(len(get_node_shape(node))))
|
||||
for d in dim:
|
||||
cur_dim = dims[d]
|
||||
if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
|
||||
self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
|
||||
|
||||
def _find_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx and compute trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
idx (list): idx of the node
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_dict = self.indice_trace_list[node_idx]
|
||||
return node_dict
|
||||
|
||||
def _find_source_trace_from_node(self, node):
|
||||
"""
|
||||
Find node source trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
idx (list): idx of the node
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_dict = self.indice_trace_list[node_idx]
|
||||
return node_dict["source"]
|
||||
|
||||
def _find_indice_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
idx (list): idx of the node
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
return self.indice_trace_list[node_idx]["indice"]
|
||||
|
||||
def _find_compute_trace_from_node(self, node):
|
||||
"""
|
||||
Find node compute trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
return self.indice_trace_list[node_idx]["compute"]
|
||||
|
||||
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
|
||||
"""
|
||||
Assign node's trace as its input node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
if input_node == None:
|
||||
input_node = find_first_tensor_arg(node)
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
|
||||
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
|
||||
|
||||
self._inherit_all_computation(input_node, node)
|
||||
|
||||
def _assign_all_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Add new indice for all node's dims.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
shape = node.meta["tensor_meta"].shape
|
||||
if shape is None:
|
||||
return
|
||||
new_trace = []
|
||||
for _ in shape:
|
||||
new_trace.append(self._add_indice())
|
||||
self.indice_trace_list[node_idx]["indice"] = new_trace
|
||||
|
||||
def _assign_transpose_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for transpose op.
|
||||
1. swap input's dim according to transpose args
|
||||
2. inherit input's computation
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
input_node = node.args[0]
|
||||
tranpose_dim = node.args[1:]
|
||||
|
||||
self._assign_indice_as_input(node, node_idx, input_node)
|
||||
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
|
||||
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
|
||||
|
||||
def _assign_permute_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for permute op.
|
||||
1. swap input's dim according to permute args
|
||||
2. inherit input's computation
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
permute_dim = flat_list(node.args[1:])
|
||||
input_node = node.args[0]
|
||||
|
||||
self._assign_indice_as_input(node, node_idx, input_node)
|
||||
for idx, d in enumerate(permute_dim):
|
||||
self._inherit_indice(input_node, d, node, idx)
|
||||
|
||||
def _assign_linear_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for linear op.
|
||||
1. copy trace from input node and change last indice according to weight
|
||||
2. mark equal for input node last indice, weight first dim and bias dim.
|
||||
3. inherit input's computation, mark computation for last dim.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
if len(node.args) == 2:
|
||||
_, weight = node.args
|
||||
else:
|
||||
_, weight, _ = node.args
|
||||
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
self._inherit_indice(weight, 1, node, -1)
|
||||
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
|
||||
def _assign_matmul_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for matmul op.
|
||||
1. copy trace from matmul_left and change last indice according to matmul_right (assert they have the same length)
|
||||
2. mark equal for input matmul_left -1 indice and matmul_right -2 dim.
|
||||
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
matmul_left, matmul_right = node.args
|
||||
|
||||
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
|
||||
self._assign_indice_as_input(node, node_idx, matmul_left)
|
||||
self._inherit_indice(matmul_right, -1, node, -1)
|
||||
|
||||
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
|
||||
def _assign_layernorm_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for layernorm op.
|
||||
1. assign indice as input node
|
||||
2. inherit computation and mark the last dim as computed.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, idx)
|
||||
self._mark_computation(node, idx, [-1])
|
||||
|
||||
def _assign_elementwise_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for element-wise ops (e.g. relu, sigmoid, add, mul).
|
||||
1. assign indice as input node
|
||||
2. inherit computation from all input nodes.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, idx)
|
||||
nodes_in = []
|
||||
for node_in in node.args:
|
||||
if type(node_in) == type(node):
|
||||
nodes_in.append(node_in)
|
||||
self._mark_computation_from_node(node_in, node)
|
||||
assert len(nodes_in) <= 2
|
||||
|
||||
def _assign_no_change_indice(self, node, idx):
"""
Assign indice for ops that keep indice unchanged (e.g. to, contiguous, clone).
"""
|
||||
self._assign_indice_as_input(node, idx)
|
||||
for node_in in node.args:
|
||||
if type(node_in) == type(node):
|
||||
self._mark_computation_from_node(node_in, node)
|
||||
|
||||
def _assign_einsum_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for einsum op.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
patterns = node.args[0]
|
||||
input_nodes = node.args[1:]
|
||||
|
||||
patterns = patterns.replace(" ", "")
|
||||
left, right = patterns.split("->")
|
||||
left = left.split(",")
|
||||
|
||||
if '...' in right:
|
||||
replace_list = "!@#$%^&*"
|
||||
target_len = len(get_node_shape(node))
|
||||
add_len = target_len - len(right) + 3
|
||||
replace_str = replace_list[:add_len]
|
||||
right = right.replace("...", replace_str)
|
||||
for ll in range(len(left)):
|
||||
left[ll] = left[ll].replace("...", replace_str)
|
||||
|
||||
all_index = []
|
||||
for i in left:
|
||||
for c in i:
|
||||
all_index.append(c)
|
||||
all_index = set(all_index)
|
||||
|
||||
for right_idx, right_indice in enumerate(right):
|
||||
for left_idx, left_str in enumerate(left):
|
||||
if right_indice in left_str:
|
||||
source_idx = left_str.index(right_indice)
|
||||
self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)
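# A standalone sketch of the '...' expansion performed above (illustration
# only; the equation, target rank and placeholder letters are made up and the
# snippet is not part of the traced class).
pattern = "...ij,...jk->...ik"
left, right = pattern.replace(" ", "").split("->")
left = left.split(",")
target_ndim = 4    # e.g. an output of shape (batch, head, i, k)
replace_str = "!@#$%^&*"[:target_ndim - len(right) + 3]
right = right.replace("...", replace_str)
left = [term.replace("...", replace_str) for term in left]
print(left, right)    # ['!@ij', '!@jk'] !@ik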
|
||||
|
||||
def _assign_softmax_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for softmax op.
|
||||
1. assign indice as input node
|
||||
2. inherit computation and mark softmax dim as computed.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, idx)
|
||||
self._mark_computation(node, idx, [node.kwargs["dim"]])
|
||||
|
||||
def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for unsqueeze op.
|
||||
1. assign new indice for unsqueeze dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._del_dim(node_idx, -1)
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
dim_idx = node.args[1]
|
||||
# convert a negative dim to its positive index in the output shape, e.g. unsqueeze(-1) == unsqueeze(output_ndim - 1)
|
||||
if dim_idx < 0:
|
||||
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
|
||||
self._add_dim(node_idx, dim_idx)
|
||||
|
||||
def _assign_dropout_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for dropout op.
|
||||
1. assign indice as input node
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
|
||||
def _assign_ones_like_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for ones_like op.
|
||||
1. assign new indice for all dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_all_indice(node, node_idx)
|
||||
|
||||
def _assign_cat_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for cat op.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
nodes_in = flat_list(node.args[0])
|
||||
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
|
||||
for n in nodes_in[1:]:
|
||||
self._mark_computation_from_node(n, node)
|
||||
cat_dim = node.kwargs["dim"]
|
||||
self._del_dim(node_idx, cat_dim)
|
||||
self._add_dim(node_idx, cat_dim)
|
||||
|
||||
def _assign_sum_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for sum op.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
nodes_in = flat_list(node.args[0])
|
||||
self._add_dim(node_idx, 0)
|
||||
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
|
||||
for n in nodes_in[1:]:
|
||||
self._mark_computation_from_node(n, node)
|
||||
cat_dim = node.kwargs["dim"]
|
||||
self._del_dim(node_idx, cat_dim)
|
||||
|
||||
def _assign_getitem_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for getitem.
|
||||
getitem can act like slice sometimes
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
node_args = flat_list(node.args[1:])
|
||||
flag = False
|
||||
for node_arg in node_args:
|
||||
node_arg_str = str(node_arg)
|
||||
if any(i == node_arg_str for i in ["None", "Ellipsis"]):
|
||||
flag = True
|
||||
break
|
||||
if "slice" in node_arg_str:
|
||||
flag = True
|
||||
break
|
||||
if flag == False:
|
||||
return
|
||||
|
||||
# node args should be like [Ellipsis, slice(start, step, end), None]
|
||||
node_shape = get_node_shape(node)
|
||||
origin_idx_count = 0
|
||||
new_idx_count = 0
|
||||
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
|
||||
for _ in range(new_dim_num):
|
||||
self._del_dim(node_idx, 0)
|
||||
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
|
||||
for _ in range(delete_dim_num):
|
||||
self._add_dim(node_idx, 0)
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
|
||||
for _, node_arg in enumerate(node_args):
|
||||
node_arg_str = str(node_arg)
|
||||
# Ellipsis means [..., ]
|
||||
if "Ellipsis" == node_arg_str:
|
||||
shape_gap = len(node_shape) - len(node_args) + 1
|
||||
origin_idx_count += shape_gap
|
||||
new_idx_count += shape_gap
|
||||
# slice(None, None, None) means all indexes
|
||||
elif "slice" in node_arg_str:
|
||||
if "slice(None, None, None)" != node_arg_str:
|
||||
self._del_dim(node_idx, new_idx_count)
|
||||
self._add_dim(node_idx, new_idx_count)
|
||||
origin_idx_count += 1
|
||||
new_idx_count += 1
|
||||
# None means a new dim
|
||||
elif "None" == node_arg_str:
|
||||
self._add_dim(node_idx, new_idx_count)
|
||||
new_idx_count += 1
|
||||
elif "0" == node_arg_str:
|
||||
self._del_dim(node_idx, new_idx_count)
|
||||
origin_idx_count += 1
|
||||
else:
|
||||
raise NotImplementedError()
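# Quick worked example of the bookkeeping above (illustration only; the args
# are made up): for x[..., 0, None, 1:3] the tracer sees one 'None' arg
# (a new dim in the output) and one literal '0' arg (a removed dim).
getitem_args = ["Ellipsis", "0", "None", "slice(1, 3, None)"]
new_dim_num = sum(1 for a in getitem_args if a == "None")
delete_dim_num = sum(1 for a in getitem_args if a == "0")
print(new_dim_num, delete_dim_num)    # 1 1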
|
||||
|
||||
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for view and reshape op.
|
||||
1. get origin shape and target shape by meta info.
|
||||
2. compute the real value of -1 in target shape.
|
||||
3. determine the changed dim, and assign indice for the generated dim.
|
||||
4. log changed dim and generated dim for restore
|
||||
5. inherit computation.
|
||||
6. TODO: look into view list to see whether the view is associated with other,
|
||||
if so, assign equal dim according to the previous view.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
# get data, turn into number
|
||||
origin_node = node.args[0]
|
||||
origin_shape = origin_node.meta["tensor_meta"].shape
|
||||
target_shape = []
|
||||
unflated_args = flat_list(node.args)
|
||||
for i in range(1, len(unflated_args)):
|
||||
if isinstance(unflated_args[i], int):
|
||||
target_shape.append(unflated_args[i])
|
||||
else:
|
||||
target_shape.append(unflated_args[i].meta["fwd_out"][0])
|
||||
|
||||
# compute the value of -1
|
||||
if -1 in target_shape:
|
||||
origin_product = 1
|
||||
for i in origin_shape:
|
||||
origin_product *= i
|
||||
target_product = -1
|
||||
for i in target_shape:
|
||||
target_product *= i
|
||||
shape_idx = target_shape.index(-1)
|
||||
target_shape[shape_idx] = origin_product // target_product
|
||||
|
||||
# determine changed dim
|
||||
len_diff = len(origin_shape) - len(target_shape)
|
||||
if len_diff == 1:
|
||||
# dim merge
|
||||
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
|
||||
dim_to = [dim_equal.index(False)]
|
||||
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._add_dim(node_idx, -1)
|
||||
elif len_diff == -1:
|
||||
# dim expand
|
||||
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
|
||||
dim_from = [dim_equal.index(False)]
|
||||
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._del_dim(node_idx, -1)
|
||||
else:
|
||||
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
|
||||
|
||||
# get new indice
|
||||
origin_trace = self._find_indice_trace_from_node(origin_node)
|
||||
self._assign_indice_as_input(node, node_idx, origin_node)
|
||||
dim_from.reverse()
|
||||
for i in dim_from:
|
||||
self._del_dim(node_idx, i)
|
||||
for i in dim_to:
|
||||
self._add_dim(node_idx, i)
|
||||
|
||||
# inherit computation
|
||||
compute_log = self._find_compute_trace_from_node(origin_node)
|
||||
for i in dim_from:
|
||||
if origin_trace[i] in compute_log:
|
||||
for j in dim_to:
|
||||
self._mark_computation(node, node_idx, [j])
|
||||
break
|
||||
|
||||
# log view, not used now
|
||||
view_dict = {
|
||||
"idx_from": [origin_trace[i] for i in dim_from],
|
||||
"dim_from": dim_from,
|
||||
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
|
||||
"dim_to": dim_to,
|
||||
}
|
||||
self.indice_view_list[node] = view_dict
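# Standalone illustration of the changed-dim detection above (made-up shapes):
# viewing (2, 3, 4, 5) as (2, 12, 5) merges dims 1 and 2 of the origin shape
# into dim 1 of the target shape.
origin_shape, target_shape = (2, 3, 4, 5), (2, 12, 5)
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
dim_to = [dim_equal.index(False)]
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
print(dim_from, dim_to)    # [1, 2] [1]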
|
||||
|
||||
def _clear_trace(self, node_idx: int) -> None:
|
||||
"""
|
||||
Clear traces that are far away from the current node to speed up computation.
|
||||
"""
|
||||
trace_range = None
|
||||
for i in range(len(self.trace_range)):
|
||||
if self.trace_range[i][1] == node_idx:
|
||||
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
|
||||
break
|
||||
if self.trace_range[i][1] > node_idx:
|
||||
break
|
||||
if trace_range is None:
|
||||
return
|
||||
|
||||
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
|
||||
active_nodes = set(flat_list(active_nodes))
|
||||
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
|
||||
for i in range(trace_range[0], trace_range[1] + 1):
|
||||
trace = self.indice_trace_list[i]
|
||||
# clear compute
|
||||
for dim_compute in trace["compute"]:
|
||||
for j in range(len(dim_compute) - 1, -1, -1):
|
||||
if dim_compute[j] < trace_range[0] and dim_compute[j] not in active_nodes:
|
||||
dim_compute.pop(j)
|
||||
continue
|
||||
# clear source
|
||||
for dim_source in trace["source"]:
|
||||
for k in list(dim_source.keys()):
|
||||
if k < trace_range[0] and k not in active_nodes:
|
||||
dim_source.pop(k)
|
||||
|
||||
def trace_indice(self):
|
||||
for idx, node in enumerate(self.node_list):
|
||||
if node.op == "placeholder":
|
||||
self._assign_all_indice(node, idx)
|
||||
elif node.op == "call_method":
|
||||
if "transpose" in node.name:
|
||||
self._assign_transpose_indice(node, idx)
|
||||
elif "permute" in node.name:
|
||||
self._assign_permute_indice(node, idx)
|
||||
elif "view" in node.name or "reshape" in node.name:
|
||||
self._assign_view_reshape_indice(node, idx)
|
||||
elif "unsqueeze" in node.name:
|
||||
self._assign_unsqueeze_indice(node, idx)
|
||||
elif any(i in node.name for i in ["to", "contiguous", "clone"]):
|
||||
self._assign_no_change_indice(node, idx)
|
||||
elif "new_ones" in node.name:
|
||||
self._assign_ones_like_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == "call_function":
|
||||
if "linear" in node.name:
|
||||
self._assign_linear_indice(node, idx)
|
||||
elif "cat" in node.name:
|
||||
self._assign_cat_indice(node, idx)
|
||||
elif "matmul" in node.name:
|
||||
self._assign_matmul_indice(node, idx)
|
||||
elif "softmax" in node.name:
|
||||
self._assign_softmax_indice(node, idx)
|
||||
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
elif "ones_like" in node.name:
|
||||
self._assign_ones_like_indice(node, idx)
|
||||
elif "dropout" in node.name:
|
||||
self._assign_dropout_indice(node, idx)
|
||||
elif "einsum" in node.name:
|
||||
self._assign_einsum_indice(node, idx)
|
||||
elif "sum" in node.name:
|
||||
self._assign_sum_indice(node, idx)
|
||||
elif "layer_norm" in node.name:
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
elif "getitem" in node.name:
|
||||
self._assign_getitem_indice(node, idx)
|
||||
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.name, "function not implemented yet!")
|
||||
elif node.op == "call_module":
|
||||
if any(n in node.name for n in ["layernorm", "norm"]):
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "module not implemented yet!")
|
||||
elif node.op == "get_attr":
|
||||
self._assign_all_indice(node, idx) # get param
|
||||
elif node.op == "output":
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
# limit trace range
|
||||
self._clear_trace(idx)
|
|
@ -0,0 +1,132 @@
|
|||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def get_logger():
|
||||
return logger
|
||||
|
||||
|
||||
def flat_list(inputs: Any) -> List:
|
||||
"""
|
||||
Flatten a (possibly nested) list recursively.
|
||||
"""
|
||||
if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
|
||||
return [inputs]
|
||||
res = []
|
||||
for i in inputs:
|
||||
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
|
||||
res.extend(flat_list(i))
|
||||
else:
|
||||
res.append(i)
|
||||
return res
|
||||
|
||||
|
||||
def find_first_tensor_arg(node: Node) -> Node:
|
||||
"""
|
||||
Find the first input tensor arg for a node
|
||||
"""
|
||||
for arg in node.args:
|
||||
if type(arg) == type(node):
|
||||
return arg
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
def is_non_compute_node(node: Node) -> bool:
|
||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
|
||||
return True
|
||||
if "getitem" in node.name:
|
||||
node_args = flat_list(node.args[1:])
|
||||
for node_arg in node_args:
|
||||
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
|
||||
return False
|
||||
if "slice" in str(node_arg):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_node_shape(node: Node) -> List:
|
||||
if hasattr(node.meta["tensor_meta"], "shape"):
|
||||
return node.meta["tensor_meta"].shape
|
||||
return None
|
||||
|
||||
|
||||
def is_non_memory_node(node: Node) -> bool:
|
||||
if "getitem" in node.name:
|
||||
return True
|
||||
if "output" in node.op:
|
||||
return True
|
||||
return is_non_compute_node(node)
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder(node):
|
||||
if "placeholder" in node.op:
|
||||
return False
|
||||
return is_non_compute_node(node)
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder_output(node):
|
||||
if "output" in node.op:
|
||||
return False
|
||||
return is_non_compute_node_except_placeholder(node)
|
||||
|
||||
|
||||
def find_idx_by_name(name, nodes_list):
|
||||
for idx, node in enumerate(nodes_list):
|
||||
if node.name == name:
|
||||
return idx
|
||||
raise RuntimeError("name %s not found in node list" % name)
|
||||
|
||||
|
||||
def delete_free_var_from_last_use(user_to_last_uses):
|
||||
for key, value in user_to_last_uses.items():
|
||||
for n in list(value):    # iterate over a copy; the list is mutated below
|
||||
if n.op == "placeholder":
|
||||
user_to_last_uses[key].remove(n)
|
||||
|
||||
|
||||
def find_chunk_all_input_nodes(nodes: List[Node]):
|
||||
"""
|
||||
Find all input nodes of the given node list.
|
||||
Input nodes are nodes outside the list whose outputs are used by nodes in the list.
|
||||
|
||||
"""
|
||||
input_nodes = []
|
||||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
if input_node not in nodes and input_node not in input_nodes:
|
||||
input_nodes.append(input_node)
|
||||
return input_nodes
|
||||
|
||||
|
||||
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||
"""
|
||||
Find the compute input and output nodes of the given node list (non-compute boundary nodes are skipped).
|
||||
Input nodes are nodes outside the list whose outputs are used inside the list;
|
||||
output nodes are nodes inside the list whose outputs are used outside the list.
|
||||
"""
|
||||
input_nodes = []
|
||||
output_nodes = []
|
||||
|
||||
# if a node has an input node which is not in the node list
|
||||
# we treat that input node as the input of the checkpoint function
|
||||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
if (input_node not in nodes and input_node not in input_nodes
|
||||
and not is_non_compute_node_except_placeholder(input_node)):
|
||||
input_nodes.append(input_node)
|
||||
|
||||
# if a node has a user node which is not in the node list
|
||||
# we treat that user node as the node receiving the current node output
|
||||
for node in nodes:
|
||||
for output_node in node.users.keys():
|
||||
if (output_node not in nodes and node not in output_nodes
|
||||
and not is_non_compute_node_except_placeholder_output(output_node)):
|
||||
output_nodes.append(node)
|
||||
|
||||
return input_nodes, output_nodes
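# A small, self-contained illustration of how the two helpers above classify
# the boundary of a candidate chunk region. The toy model and the hard-coded
# node slice are made up for the example; in the real flow the region comes
# from the chunk-region search, not from a manual slice.
import torch
from torch.fx import symbolic_trace


class _Toy(torch.nn.Module):

    def forward(self, x):
        y = x + 1
        z = y * 2
        return z.relu()


nodes = list(symbolic_trace(_Toy()).graph.nodes)
region = nodes[2:4]    # pretend the mul and relu nodes form one chunk
print(find_chunk_all_input_nodes(region))    # [add]
print(find_chunk_compute_input_and_output_nodes(region))    # ([add], [relu])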
|
|
@ -1,5 +1,5 @@
|
|||
from typing import List
|
||||
import socket
|
||||
from typing import List
|
||||
|
||||
|
||||
class HostInfo:
|
||||
|
@ -35,9 +35,14 @@ class HostInfo:
|
|||
|
||||
if port is None:
|
||||
port = 22 # no port specified, lets just use the ssh port
|
||||
hostname = socket.getfqdn(hostname)
|
||||
|
||||
# socket.getfqdn("127.0.0.1") does not return localhost
|
||||
# on some users' machines
|
||||
# thus, we directly return True if hostname is localhost, 127.0.0.1 or 0.0.0.0
|
||||
if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
|
||||
return True
|
||||
|
||||
hostname = socket.getfqdn(hostname)
|
||||
localhost = socket.gethostname()
|
||||
localaddrs = socket.getaddrinfo(localhost, port)
|
||||
targetaddrs = socket.getaddrinfo(hostname, port)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import fabric
|
||||
from .hostinfo import HostInfo, HostInfoList
|
||||
from multiprocessing import Pipe, Process
|
||||
from multiprocessing import connection as mp_connection
|
||||
|
||||
import click
|
||||
import fabric
|
||||
|
||||
from .hostinfo import HostInfo, HostInfoList
|
||||
|
||||
|
||||
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
|
||||
|
@ -45,8 +47,10 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
|
|||
# execute on the remote machine
|
||||
fab_conn.run(cmds, hide=False)
|
||||
send_conn.send('success')
|
||||
except:
|
||||
click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
|
||||
)
|
||||
send_conn.send('failure')
|
||||
|
||||
# shutdown
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
import click
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
from colossalai.context import Config
|
||||
from .multinode_runner import MultiNodeRunner
|
||||
from .hostinfo import HostInfo, HostInfoList
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import click
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.context import Config
|
||||
|
||||
from .hostinfo import HostInfo, HostInfoList
|
||||
from .multinode_runner import MultiNodeRunner
|
||||
|
||||
# Constants that define our syntax
|
||||
NODE_SEP = ','
|
||||
|
||||
|
@ -15,7 +18,7 @@ NODE_SEP = ','
|
|||
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
|
||||
"""
|
||||
Parse the hostfile to obtain a list of hosts.
|
||||
|
||||
|
||||
A hostfile should look like:
|
||||
worker-0
|
||||
worker-1
|
||||
|
@ -63,7 +66,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
|
|||
device_pool (HostInfoList): a list of HostInfo objects
|
||||
include_str (str): --include option passed by user, default None
|
||||
exclude_str (str): --exclude option passed by user, default None
|
||||
|
||||
|
||||
Returns:
|
||||
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
|
||||
'''
|
||||
|
@ -192,7 +195,7 @@ def launch_multi_processes(args: Config) -> None:
|
|||
Launch multiple processes on a single node or multiple nodes.
|
||||
|
||||
The overall logic can be summarized as the pseudo code below:
|
||||
|
||||
|
||||
if hostfile given:
|
||||
hostinfo = parse_hostfile(hostfile)
|
||||
hostinfo = include_or_exclude_hosts(hostinfo)
|
||||
|
@ -202,7 +205,7 @@ def launch_multi_processes(args: Config) -> None:
|
|||
launch_on_multi_nodes(hostinfo)
|
||||
else:
|
||||
launch_on_current_node()
|
||||
|
||||
|
||||
Args:
|
||||
args (Config): the arguments taken from command line
|
||||
|
||||
|
@ -276,6 +279,33 @@ def launch_multi_processes(args: Config) -> None:
|
|||
extra_launch_args=args.extra_launch_args)
|
||||
runner.send(hostinfo=hostinfo, cmd=cmd)
|
||||
|
||||
runner.recv_from_all()
|
||||
# start training
|
||||
msg_from_node = runner.recv_from_all()
|
||||
has_error = False
|
||||
|
||||
# print node status
|
||||
click.echo("\n====== Training on All Nodes =====")
|
||||
for hostname, msg in msg_from_node.items():
|
||||
click.echo(f"{hostname}: {msg}")
|
||||
|
||||
# check if a process failed
|
||||
if msg == "failure":
|
||||
has_error = True
|
||||
|
||||
# stop all nodes
|
||||
runner.stop_all()
|
||||
runner.recv_from_all()
|
||||
|
||||
# receive the stop status
|
||||
msg_from_node = runner.recv_from_all()
|
||||
|
||||
# print node status
|
||||
click.echo("\n====== Stopping All Nodes =====")
|
||||
for hostname, msg in msg_from_node.items():
|
||||
click.echo(f"{hostname}: {msg}")
|
||||
|
||||
# give the process an exit code
|
||||
# so that it behaves like a normal process
|
||||
if has_error:
|
||||
sys.exit(1)
|
||||
else:
|
||||
sys.exit(0)
|
||||
|
|
|
@ -381,6 +381,8 @@ class AlphaBetaProfiler:
|
|||
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
|
||||
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
|
||||
mesh_alpha = [first_latency, second_latency]
|
||||
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
|
||||
# The beta values have been temporarily enlarged by 1e10 because the computation cost
|
||||
# is still estimated in the unit of TFLOPs instead of time. We will remove this factor in the future.
|
||||
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
|
||||
|
||||
return mesh_alpha, mesh_beta
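# Worked example of the temporary 1e10 scaling above (numbers are made up):
# for a measured bandwidth of 100 GB/s the scaled beta is 1e10 / 1e11 = 0.1
# instead of the raw per-byte cost 1 / 1e11 = 1e-11.
measured_bandwidth = 100e9    # bytes per second, hypothetical
print(1e10 / measured_bandwidth, 1 / measured_bandwidth)    # 0.1 1e-11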
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -15,7 +16,8 @@ class DeviceMesh:
|
|||
|
||||
Arguments:
|
||||
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
|
||||
mesh_shape (torch.Size): shape of logical view.
|
||||
logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
|
||||
mesh_shape (torch.Size, optional): shape of logical view.
|
||||
mesh_alpha (List[float], optional): coefficients used for computing
|
||||
communication cost (default: None)
|
||||
mesh_beta (List[float], optional): coefficients used for computing
|
||||
|
@ -28,15 +30,21 @@ class DeviceMesh:
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
physical_mesh_id,
|
||||
mesh_shape,
|
||||
mesh_alpha=None,
|
||||
mesh_beta=None,
|
||||
init_process_group=False,
|
||||
need_flatten=True):
|
||||
physical_mesh_id: torch.Tensor,
|
||||
mesh_shape: torch.Size = None,
|
||||
logical_mesh_id: torch.Tensor = None,
|
||||
mesh_alpha: List[float] = None,
|
||||
mesh_beta: List[float] = None,
|
||||
init_process_group: bool = False,
|
||||
need_flatten: bool = True):
|
||||
self.physical_mesh_id = physical_mesh_id
|
||||
self.mesh_shape = mesh_shape
|
||||
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
|
||||
if logical_mesh_id is None:
|
||||
self.mesh_shape = mesh_shape
|
||||
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
|
||||
else:
|
||||
self._logical_mesh_id = logical_mesh_id
|
||||
self.mesh_shape = self._logical_mesh_id.shape
|
||||
|
||||
# map global rank into logical rank
|
||||
self.convert_map = {}
|
||||
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
|
||||
|
@ -54,8 +62,8 @@ class DeviceMesh:
|
|||
if self.need_flatten and self._logical_mesh_id.dim() > 1:
|
||||
self.flatten_device_mesh = self.flatten()
|
||||
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
|
||||
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
|
||||
self.mesh_beta)
|
||||
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
|
||||
# self.mesh_beta)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
|
|
@ -1,24 +1,34 @@
|
|||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules.module import _addindent
|
||||
from typing import Type, Dict, List, Any, Union, Optional, Set
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
|
||||
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
|
||||
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
|
||||
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
|
||||
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
COLOGM = True
|
||||
except:
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.graph_module import GraphModule
|
||||
COLOGM = False
|
||||
|
||||
if COLOGM:
|
||||
|
||||
class ColoGraphModule(GraphModule):
|
||||
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
||||
def __init__(self,
|
||||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: Graph,
|
||||
class_name: str = 'GraphModule',
|
||||
ckpt_codegen: bool = True):
|
||||
if ckpt_codegen:
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
super().__init__(root, graph, class_name)
|
||||
|
||||
def bind(self, ckpt_def, globals):
|
||||
|
|
|
@ -9,6 +9,40 @@ def pipe_split():
|
|||
pass
|
||||
|
||||
|
||||
def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
"""
|
||||
In avgcompute_split_pass, we split the module by forward FLOPs.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
# To use avgcompute_split_pass, we need to run the meta_info_prop interpreter first.
|
||||
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
|
||||
check_node = list(mod_graph.nodes)[0]
|
||||
if 'tensor_meta' not in check_node.meta:
|
||||
return balanced_split_pass(gm, pp_size)
|
||||
|
||||
total_fwd_flop = 0
|
||||
for node in mod_graph.nodes:
|
||||
total_fwd_flop += node.fwd_flop
|
||||
|
||||
partition_flop = total_fwd_flop // pp_size
|
||||
accumulate_fwd_flop = 0
|
||||
for node in mod_graph.nodes:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
if 'pipe_split' in node.name:
|
||||
continue
|
||||
accumulate_fwd_flop += node.fwd_flop
|
||||
if accumulate_fwd_flop >= partition_flop:
|
||||
total_fwd_flop = total_fwd_flop - accumulate_fwd_flop
|
||||
accumulate_fwd_flop = 0
|
||||
pp_size -= 1
|
||||
partition_flop = total_fwd_flop // pp_size
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
gm.recompile()
|
||||
return gm
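# Standalone illustration of the greedy FLOP-balancing rule above (the flop
# numbers are made up): accumulate fwd_flop and cut whenever the running sum
# reaches total / remaining_partitions.
flops = [10, 30, 20, 25, 15]
pp_size, total = 3, sum(flops)
partition, acc, cuts = total // pp_size, 0, []
for i, f in enumerate(flops):
    if pp_size <= 1:
        break
    acc += f
    if acc >= partition:
        cuts.append(i)    # a pipe_split node would be inserted after node i
        total -= acc
        acc, pp_size = 0, pp_size - 1
        partition = total // pp_size
print(cuts)    # [1, 3]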
|
||||
|
||||
|
||||
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
"""
|
||||
In avgnode_split_pass, simply split the graph by node count.
|
||||
|
@ -104,8 +138,10 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
|
|||
continue
|
||||
accumulate_node_size += node.node_size
|
||||
if accumulate_node_size >= partition_size:
|
||||
total_element_size = total_element_size - accumulate_node_size
|
||||
accumulate_node_size = 0
|
||||
pp_size -= 1
|
||||
partition_size = total_element_size // pp_size
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
gm.recompile()
|
||||
|
|
|
@ -112,7 +112,8 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
n.meta['tensor_meta'] = tensor_meta
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
|
||||
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
|
||||
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
|
||||
n.meta['type'] = type(result)
|
||||
|
||||
# retain the autograd graph
|
||||
|
|
|
@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.sum.default,
|
||||
aten.sum.dim_IntList,
|
||||
aten.mean.dim,
|
||||
aten.sub.Tensor,
|
||||
aten.sub_.Tensor,
|
||||
|
||||
# activation op
|
||||
aten.hardswish.default,
|
||||
|
@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.where.self,
|
||||
aten.zero_.default,
|
||||
aten.zeros_like.default,
|
||||
]
|
||||
aten.fill_.Scalar
|
||||
] # yapf: disable
|
||||
|
||||
for op in zero_flop_aten:
|
||||
flop_mapping[op] = zero_flop_jit
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
import uuid
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.types import _bool, _device, _dtype
|
||||
|
@ -28,8 +26,6 @@ class MetaTensor(torch.Tensor):
|
|||
|
||||
_tensor: torch.Tensor
|
||||
|
||||
__slots__ = ['_tensor']
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, fake_device=None):
|
||||
# Avoid multiple wrapping
|
||||
|
@ -47,7 +43,7 @@ class MetaTensor(torch.Tensor):
|
|||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=fake_device if fake_device is not None else elem.device,
|
||||
device=fake_device if fake_device is not None else torch.device('cpu'),
|
||||
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
|
@ -59,8 +55,8 @@ class MetaTensor(torch.Tensor):
|
|||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
|
||||
return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
|
||||
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
|
||||
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
@ -76,13 +72,13 @@ class MetaTensor(torch.Tensor):
|
|||
x = x.to(torch.device('meta'))
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
if 'device' in kwargs:
|
||||
fake_device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
|
@ -118,23 +114,24 @@ class MetaTensor(torch.Tensor):
|
|||
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
|
||||
"""
|
||||
# this imitates c++ function in the way of @overload
|
||||
device = None
|
||||
for arg in args:
|
||||
if isinstance(arg, str) or isinstance(arg, _device):
|
||||
device = arg
|
||||
if 'device' in kwargs:
|
||||
device = kwargs['device']
|
||||
result = super().to(*args, **kwargs)
|
||||
if device is not None:
|
||||
result = MetaTensor(result, fake_device=device)
|
||||
return result
|
||||
fake_device = None
|
||||
|
||||
def replace(x):
|
||||
nonlocal fake_device
|
||||
if isinstance(x, str) or isinstance(x, _device):
|
||||
fake_device = x
|
||||
return 'meta'
|
||||
return x
|
||||
|
||||
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
return MetaTensor(elem, fake_device=fake_device)
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
if self.device.type == 'cpu':
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cpu', **kwargs)
|
||||
|
||||
def cuda(self, *args, **kwargs):
|
||||
if self.device.type == 'cuda':
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cuda', **kwargs)
|
||||
def cuda(self, device=None, non_blocking=False):
|
||||
if device is not None:
|
||||
return self.to(device=device, non_blocking=non_blocking)
|
||||
return self.to(device='cuda:0', non_blocking=non_blocking)
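# Hedged usage sketch of the fake-device behaviour implemented above: the
# payload stays on the meta device while `.device` reports the fake device.
# (MetaTensor is the class defined in this file; the printed values are what
# we expect, not captured output.)
import torch

t = MetaTensor(torch.empty(4, 4, device='meta'), fake_device='cuda:0')
print(t.device)               # expected: cuda:0
print(t.to('cpu').device)     # expected: cpu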
|
||||
|
|
|
@ -13,6 +13,7 @@ def symbolic_trace(
|
|||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
concrete_args: Optional[Dict[str, Any]] = None,
|
||||
meta_args: Optional[Dict[str, Any]] = None,
|
||||
trace_act_ckpt=False,
|
||||
) -> ColoGraphModule:
|
||||
"""
|
||||
Symbolic tracing API
|
||||
|
@ -49,6 +50,6 @@ def symbolic_trace(
|
|||
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
|
||||
|
||||
"""
|
||||
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
||||
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
return ColoGraphModule(root, graph, name)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import enum
|
||||
import functools
|
||||
import operator
|
||||
import inspect
|
||||
import operator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
|
@ -286,7 +286,6 @@ class ColoTracer(Tracer):
|
|||
self.graph.lint()
|
||||
return self.graph
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_activation_checkpoint(self, enabled: bool):
|
||||
if enabled:
|
||||
|
@ -316,7 +315,6 @@ class ColoTracer(Tracer):
|
|||
# recover the checkpoint function upon exit
|
||||
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
|
||||
|
||||
|
||||
def _post_check(self, non_concrete_arg_names: Set[str]):
|
||||
# This is necessary because concrete args are added as input to the traced module since
|
||||
# https://github.com/pytorch/pytorch/pull/55888.
|
||||
|
@ -385,18 +383,23 @@ def symbolic_trace(
|
|||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
concrete_args: Optional[Dict[str, Any]] = None,
|
||||
meta_args: Optional[Dict[str, Any]] = None,
|
||||
trace_act_ckpt=False,
|
||||
) -> ColoGraphModule:
|
||||
if is_compatible_with_meta():
|
||||
if meta_args is not None:
|
||||
root.to(default_device())
|
||||
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
|
||||
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
|
||||
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
|
||||
concrete_args=concrete_args,
|
||||
meta_args=tree_map(wrap_fn, meta_args))
|
||||
root.cpu()
|
||||
else:
|
||||
graph = Tracer().trace(root, concrete_args=concrete_args)
|
||||
else:
|
||||
from .tracer import ColoTracer as OrigColoTracer
|
||||
graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
||||
graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
|
||||
concrete_args=concrete_args,
|
||||
meta_args=meta_args)
|
||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
return ColoGraphModule(root, graph, name)
|
||||
|
||||
|
@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
|
|||
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
|
||||
node.kwargs)
|
||||
|
||||
|
||||
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
|
||||
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
|
||||
if kind == 'placeholder':
|
||||
meta_out = meta_args[target] if target in meta_args else concrete_args.get(
|
||||
_truncate_suffix(target), None)
|
||||
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
|
||||
elif kind == 'get_attr':
|
||||
attr_itr = root
|
||||
atoms = target.split(".")
|
||||
|
@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
|
|||
else:
|
||||
if target not in _TensorPropertyMethod:
|
||||
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
|
||||
**tree_map(unwrap_fn, kwargs))
|
||||
**tree_map(unwrap_fn, kwargs))
|
||||
elif kind == 'call_module':
|
||||
mod = root.get_submodule(target)
|
||||
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
|
||||
|
@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
|
|||
meta_out = None
|
||||
return meta_out
|
||||
|
||||
|
||||
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
|
||||
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
|
||||
meta_out = meta_args[target]
|
||||
|
@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
|
|||
return meta_out
|
||||
|
||||
|
||||
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
|
||||
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
|
||||
result_graph = Graph()
|
||||
value_remap = {}
|
||||
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
|
||||
|
@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
|
|||
if target == torch.nn.functional.linear:
|
||||
if 'bias' in kwargs and kwargs['bias'] is not None:
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
||||
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
|
||||
function_to_substitute)
|
||||
else:
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
||||
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
|
||||
function_to_substitute)
|
||||
elif bias_addition_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
||||
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
|
||||
function_to_substitute)
|
||||
|
||||
elif kind == "call_method":
|
||||
method = getattr(args_metas[0].__class__, target)
|
||||
if bias_addition_method.has(method):
|
||||
function_to_substitute = method_to_func_dict[method]
|
||||
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
||||
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
|
||||
function_to_substitute)
|
||||
|
||||
elif kind == "call_module":
|
||||
# if not hasattr(self, "orig_forward"):
|
||||
|
@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
|
|||
mod_type = type(mod)
|
||||
if bias_addition_module.has(mod_type) and mod.bias is not None:
|
||||
function_to_substitute = module_to_func_dict[mod_type]
|
||||
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
||||
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
|
||||
function_to_substitute)
|
||||
|
||||
if handle is not None:
|
||||
handle.generate()
|
||||
for node_inserted in tracer.graph.nodes:
|
||||
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
|
||||
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
|
||||
last_node = value_remap[node_inserted]
|
||||
value_remap[orig_node] = last_node
|
||||
else:
|
||||
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
|
||||
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
|
||||
|
||||
del tracer
|
||||
|
||||
gm.graph = result_graph
|
||||
gm.recompile()
|
||||
meta_prop_pass(gm, root_model, meta_args)
|
||||
|
||||
|
|
|
@ -6,17 +6,14 @@ import torch.nn as nn
|
|||
|
||||
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
|
||||
from colossalai.tensor import ColoParameter
|
||||
|
||||
|
||||
def in_ddp(param: nn.Parameter) -> bool:
|
||||
return not getattr(param, '_ddp_to_ignore', False)
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
|
||||
|
||||
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
|
||||
"""
|
||||
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
|
||||
"""
|
||||
params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
|
||||
params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
|
||||
params_size_arr = np.array(params_size)
|
||||
|
||||
std = np.std(params_size_arr)
|
||||
|
@ -56,7 +53,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
|
|||
params_dict: Dict[int, List[ColoParameter]] = dict()
|
||||
for param in param_order.generate():
|
||||
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
|
||||
if not in_ddp(param):
|
||||
if is_ddp_ignored(param):
|
||||
continue
|
||||
|
||||
param_key = param.process_group.dp_world_size()
|
||||
|
|
|
@ -6,8 +6,14 @@ import torch.distributed as dist
|
|||
import torch.nn as nn
|
||||
|
||||
from colossalai.gemini.chunk import ChunkManager
|
||||
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
|
||||
from colossalai.gemini.memory_tracer import MemStats
|
||||
from colossalai.gemini.chunk.search_utils import search_chunk_configuration
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
|
||||
|
||||
def safe_div(a, b):
|
||||
if a == 0:
|
||||
return 0
|
||||
return a / b
|
||||
|
||||
|
||||
def init_chunk_manager(model: nn.Module,
|
||||
|
@ -16,7 +22,6 @@ def init_chunk_manager(model: nn.Module,
|
|||
search_range_mb: Optional[float] = None,
|
||||
min_chunk_size_mb: Optional[float] = None,
|
||||
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
|
||||
|
||||
kwargs_dict = dict()
|
||||
|
||||
if hidden_dim:
|
||||
|
@ -34,7 +39,7 @@ def init_chunk_manager(model: nn.Module,
|
|||
if filter_exlarge_params:
|
||||
kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
|
||||
|
||||
params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
|
||||
params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
|
||||
total_size = sum(params_sizes) / 1024**2
|
||||
|
||||
dist.barrier()
|
||||
|
@ -50,7 +55,7 @@ def init_chunk_manager(model: nn.Module,
|
|||
if dist.get_rank() == 0:
|
||||
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
|
||||
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
|
||||
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
|
||||
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
|
||||
sep='',
|
||||
flush=True)
|
||||
dist.barrier()
|
||||
|
|
|
@ -50,6 +50,17 @@ class GeminiManager:
|
|||
self._warmup = True
|
||||
self._comp_cuda_demand_time = 0
|
||||
|
||||
def reset_attributes(self):
|
||||
self._compute_idx = -1
|
||||
self._h2d_volume = 0
|
||||
self._d2h_volume = 0
|
||||
self._layout_time = 0
|
||||
self._evict_time = 0
|
||||
self._comp_cuda_demand_time = 0
|
||||
|
||||
def is_warmup(self):
|
||||
return self._warmup
|
||||
|
||||
def memstats(self):
|
||||
"""memstats
|
||||
|
||||
|
@ -73,12 +84,7 @@ class GeminiManager:
|
|||
if self._mem_stats_collector and self._warmup:
|
||||
self._mem_stats_collector.finish_collection()
|
||||
self._warmup = False
|
||||
self._compute_idx = -1
|
||||
self._h2d_volume = 0
|
||||
self._d2h_volume = 0
|
||||
self._layout_time = 0
|
||||
self._evict_time = 0
|
||||
self._comp_cuda_demand_time = 0
|
||||
self.reset_attributes()
|
||||
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
|
||||
""" Adjust the layout of stateful tensors according to the information provided
|
||||
|
|
|
@ -15,26 +15,25 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.gemini.ophooks import BaseOpHook
|
||||
|
||||
from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
|
||||
from colossalai.utils.moe import sync_moe_model_param
|
||||
|
||||
from colossalai.amp import AMP_TYPE, convert_to_amp
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.builder.builder import build_gradient_handler
|
||||
from colossalai.context import Config, ConfigException, ParallelMode
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.engine.gradient_accumulation import accumulate_gradient
|
||||
|
||||
from colossalai.engine.schedule import (
|
||||
InterleavedPipelineSchedule,
|
||||
NonPipelineSchedule,
|
||||
PipelineSchedule,
|
||||
get_tensor_shape,
|
||||
)
|
||||
from colossalai.gemini.ophooks import BaseOpHook
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
|
||||
|
||||
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
|
||||
from colossalai.utils.moe import sync_moe_model_param
|
||||
from colossalai.zero import convert_to_zero_v2
|
||||
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
|
||||
|
||||
|
@ -301,9 +300,9 @@ def initialize(model: nn.Module,
|
|||
model = model().to(get_current_device())
|
||||
|
||||
# optimizer maybe a optimizer_cls
|
||||
logger.warning("Initializing an non ZeRO model with optimizer class")
|
||||
if isinstance(optimizer, Callable):
|
||||
optimizer = optimizer(model.parameters())
|
||||
logger.warning("Initializing an non ZeRO model with optimizer class")
|
||||
|
||||
if not use_zero:
|
||||
if is_using_sequence():
|
||||
|
|
|
@ -114,6 +114,13 @@ class FusedScaleMaskSoftmax(nn.Module):
|
|||
self.softmax_in_fp32 = softmax_in_fp32
|
||||
self.scale = scale
|
||||
|
||||
try:
|
||||
from colossalai._C import scaled_masked_softmax
|
||||
except ImportError:
|
||||
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
|
||||
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
|
||||
self.scaled_masked_softmax = scaled_masked_softmax
|
||||
|
||||
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
|
||||
|
||||
def forward(self, input, mask):
|
||||
|
@ -178,11 +185,5 @@ class FusedScaleMaskSoftmax(nn.Module):
|
|||
|
||||
return probs
|
||||
|
||||
@staticmethod
|
||||
def get_batch_per_block(sq, sk, b, np):
|
||||
try:
|
||||
import colossalai._C.scaled_masked_softmax
|
||||
except ImportError:
|
||||
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
|
||||
|
||||
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
|
||||
def get_batch_per_block(self, sq, sk, b, np):
|
||||
return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
|
||||
|
|
|
@ -19,7 +19,7 @@ class CPUAdam(NVMeOptimizer):
|
|||
* Parameters on GPU and gradients on GPU are allowed.
|
||||
* Parameters on GPU and gradients on CPU are **not** allowed.
|
||||
|
||||
Requires ColossalAI to be installed via ``pip install .``.
|
||||
`CPUAdam` requires CUDA extensions which can be built during installation or runtime.
|
||||
|
||||
This version of CPU Adam accelerates parameter updates on CPU with SIMD.
|
||||
Support of AVX2 or AVX512 is required.
|
||||
|
|
|
@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
|
|||
class FusedAdam(torch.optim.Optimizer):
|
||||
"""Implements Adam algorithm.
|
||||
|
||||
Currently GPU-only. Requires ColossalAI to be installed via
|
||||
``pip install .``.
|
||||
`FusedAdam` requires CUDA extensions which can be built during installation or runtime.
|
||||
|
||||
This version of fused Adam implements 2 fusions.
|
||||
|
||||
|
|
|
@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
|
|||
class FusedLAMB(torch.optim.Optimizer):
|
||||
"""Implements LAMB algorithm.
|
||||
|
||||
Currently GPU-only. Requires ColossalAI to be installed via
|
||||
``pip install .``.
|
||||
`FusedLAMB` requires CUDA extensions which can be built during installation or runtime.
|
||||
|
||||
This version of fused LAMB implements 2 fusions.
|
||||
|
||||
|
|
|
@ -10,8 +10,7 @@ from colossalai.utils import multi_tensor_applier
|
|||
class FusedSGD(Optimizer):
|
||||
r"""Implements stochastic gradient descent (optionally with momentum).
|
||||
|
||||
Currently GPU-only. Requires ColossalAI to be installed via
|
||||
``pip install .``.
|
||||
`FusedSGD` requires CUDA extensions which can be built during installation or runtime.
|
||||
|
||||
This version of fused SGD implements 2 fusions.
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ class HybridAdam(NVMeOptimizer):
|
|||
* Parameters on GPU and gradients on GPU are allowed.
|
||||
* Parameters on GPU and gradients on CPU are **not** allowed.
|
||||
|
||||
Requires ColossalAI to be installed via ``pip install .``
|
||||
`HybriadAdam` requires CUDA extensions which can be built during installation or runtime.
|
||||
|
||||
This version of Hybrid Adam is a hybrid of CPUAdam and FusedAdam.
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import math
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Set, Tuple
|
||||
|
||||
|
@ -12,7 +13,7 @@ from colossalai.gemini.chunk import Chunk, ChunkManager
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
|
||||
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||
from colossalai.utils import disposable, get_current_device
|
||||
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||
|
||||
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
||||
|
||||
|
@ -78,8 +79,16 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
if self.clipping_flag:
|
||||
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
|
||||
|
||||
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
for p, fp32_p in zip(params_list, module.fp32_params):
|
||||
ddp_param_list = []
|
||||
for name, param in module.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
if param.requires_grad:
|
||||
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
|
||||
"You should handle its optimizer update by yourself!")
|
||||
else:
|
||||
ddp_param_list.append(param)
|
||||
|
||||
for p, fp32_p in zip(ddp_param_list, module.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
if chunk_16 not in self.chunk16_set:
|
||||
chunk_16.l2_norm_flag = self.clipping_flag
|
||||
|
@ -140,6 +149,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _clear_global_norm(self) -> None:
|
||||
for c16 in self.chunk16_set:
|
||||
c16.l2_norm = None
|
||||
|
||||
def _calc_global_norm(self) -> float:
|
||||
norm_sqr: float = 0.0
|
||||
group_to_norm = dict()
|
||||
|
@ -201,6 +214,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self.optim_state = OptimState.UNSCALED # no need to unscale grad
|
||||
self.grad_scaler.update(found_inf) # update gradient scaler
|
||||
self._logger.info(f'Found overflow. Skip step')
|
||||
self._clear_global_norm() # clear recorded norm
|
||||
self.zero_grad() # reset all gradients
|
||||
self._update_fp16_params()
|
||||
return
|
||||
|
@ -285,6 +299,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
fake_params_list = list()
|
||||
|
||||
for param in group['params']:
|
||||
if is_ddp_ignored(param):
|
||||
continue
|
||||
chunk16 = self.chunk_manager.get_chunk(param)
|
||||
range_pair = get_range_pair(chunk16, param)
|
||||
if range_pair[0] >= range_pair[1]:
|
||||
|
|
|
@ -12,12 +12,14 @@ from colossalai.gemini.memory_tracer import OrderedParamGenerator
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor import ReplicaSpec
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import get_current_device, is_ddp_ignored
|
||||
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
|
||||
|
||||
from .reducer import Reducer
|
||||
from .utils import get_static_torch_model
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
|
@ -80,7 +82,7 @@ class ColoDDP(torch.nn.Module):
|
|||
self.reducer = Reducer(bucket_cap_mb)
|
||||
self.rebuild_bucket = rebuild_bucket
|
||||
for p in module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
if p.requires_grad:
|
||||
p.register_hook(partial(self.grad_handle, p))
|
||||
|
@ -115,7 +117,7 @@ class ColoDDP(torch.nn.Module):
|
|||
if self.rebuild_bucket:
|
||||
self.reducer.free()
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
if p.grad.device.type != "cpu":
|
||||
p.grad = p._saved_grad
|
||||
|
@ -199,14 +201,18 @@ class ZeroDDP(ColoDDP):
|
|||
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
|
||||
For more details, see the API reference of ``GeminiManager``.
|
||||
pin_memory (bool): Chunks on CPU Memory use pin-memory.
|
||||
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
|
||||
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
|
||||
Defaults to False.
|
||||
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
|
||||
Defaults to False. Users can set it to True when they are sure that they only need DDP.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False) -> None:
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False) -> None:
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
|
@ -231,8 +237,11 @@ class ZeroDDP(ColoDDP):
|
|||
for p in param_order.generate():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
p.data = p.data.half()
|
||||
if strict_ddp_mode and not p.is_replicate():
|
||||
p.set_dist_spec(ReplicaSpec())
|
||||
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
|
||||
continue
|
||||
|
||||
fp32_data = p.data.float()
|
||||
|
@ -251,10 +260,11 @@ class ZeroDDP(ColoDDP):
|
|||
pin_memory=pin_memory)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
self._cast_buffers()
|
||||
|
||||
params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
|
||||
for p, fp32_p in zip(params_list, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
|
@ -266,19 +276,42 @@ class ZeroDDP(ColoDDP):
|
|||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
"""
|
||||
access_list = list(self.chunk_manager.accessed_chunks)
|
||||
# we need to scatter all accessed chunks and move them to their original places
|
||||
for chunk in access_list:
|
||||
assert chunk.can_release
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
first_param = next(iter(chunk.tensors_info))
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
|
||||
assert self.chunk_manager.accessed_mem == 0
|
||||
# reset all recorded attributes
|
||||
self.gemini_manager.reset_attributes()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# check whether we are in inference mode
|
||||
grad_flag = torch.is_grad_enabled()
|
||||
if not grad_flag:
|
||||
assert not self.gemini_manager.is_warmup(), "You should run a complete iteration as your warmup iteration"
|
||||
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
self.gemini_manager.pre_iter(*args)
|
||||
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
# scatter chunks in the inference mode
|
||||
if not grad_flag:
|
||||
self._post_forward()
|
||||
|
||||
if self.force_outputs_fp32:
|
||||
return _cast_float(outputs, torch.float)
|
||||
return outputs
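The inference path above can be exercised with a minimal sketch. This assumes `model` is an already-constructed `ZeroDDP` instance that has finished at least one complete training iteration (its warmup) and that `eval_batch` is a suitable input; both names are placeholders.

```python
import torch

def run_inference(model, eval_batch):
    # Sketch only: `model` is assumed to be a ZeroDDP instance that has already
    # completed at least one full training iteration (its warmup phase).
    with torch.no_grad():            # takes the inference branch above
        out = model(eval_batch)      # _post_forward() then scatters accessed chunks
    return out
```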
|
||||
|
||||
def _setup_grads_ptr(self):
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
p.grad = None
|
||||
|
||||
|
@ -331,12 +364,10 @@ class ZeroDDP(ColoDDP):
|
|||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
||||
r"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are
|
||||
included. Keys are corresponding parameter and buffer names.
|
||||
Parameters and buffers set to ``None`` are not included.
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
|
||||
"""
|
||||
Args:
|
||||
strict (bool): whether to return the whole model state as PyTorch's `Module.state_dict()` does
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
|
@ -346,7 +377,30 @@ class ZeroDDP(ColoDDP):
|
|||
|
||||
>>> module.state_dict().keys()
|
||||
['bias', 'weight']
|
||||
"""
|
||||
if strict:
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
|
||||
return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
return self._non_strict_state_dict(destination=destination,
|
||||
prefix=prefix,
|
||||
keep_vars=keep_vars,
|
||||
only_rank_0=only_rank_0)
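A hedged sketch of how the new `strict` flag might be used when saving and reloading a checkpoint. Both `zero_model` (a `ZeroDDP` instance) and `torch_model` (a plain module with the same architecture) are assumed to exist; the file path is illustrative.

```python
import torch

def save_and_reload(zero_model, torch_model, path="model.pt"):
    # `zero_model` is assumed to be a ZeroDDP instance; `torch_model` is a
    # plain torch.nn.Module with the same architecture (both hypothetical).
    # strict=True (default): gather a full state dict via the static torch model.
    full_sd = zero_model.state_dict(only_rank_0=True, strict=True)
    torch.save(full_sd, path)

    # strict=False: shared parameters that were already emitted once are skipped,
    # so the consumer must also load with strict=False.
    loose_sd = zero_model.state_dict(only_rank_0=True, strict=False)
    torch_model.load_state_dict(loose_sd, strict=False)
```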
|
||||
|
||||
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
||||
"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
Keys are corresponding parameter and buffer names.
|
||||
Parameters and buffers set to ``None`` are not included.
|
||||
|
||||
Warning: The non-strict state dict ignores a parameter if its tensor is shared with another
parameter that has already been included in the dictionary.
When you load this state dict, you should set the argument `strict` to False.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
a dictionary containing a whole state of the module
|
||||
"""
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
|
@ -405,8 +459,14 @@ class ZeroDDP(ColoDDP):
|
|||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||
# TODO: (HELSON) deal with ddp ignored parameters
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
ddp_param_list = []
|
||||
for name, param in self.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
else:
|
||||
ddp_param_list.append((name, param))
|
||||
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
|
||||
if p is not None:
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
|
@ -542,8 +602,16 @@ class ZeroDDP(ColoDDP):
|
|||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
ddp_param_list = []
|
||||
for name, param in self.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
load(name, param, param.copy_)
|
||||
else:
|
||||
ddp_param_list.append((name, param))
|
||||
|
||||
fp32_to_name = dict()
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
|
||||
if p is not None:
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ class GeminiDDP(ZeroDDP):
|
|||
placement_policy: str = "cpu",
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
search_range_mb: int = 32,
|
||||
hidden_dim: Optional[int] = None,
|
||||
min_chunk_size_mb: Optional[float] = None,
|
||||
|
@ -54,4 +55,4 @@ class GeminiDDP(ZeroDDP):
|
|||
search_range_mb=search_range_mb,
|
||||
min_chunk_size_mb=min_chunk_size_mb)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
|
||||
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
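A minimal construction sketch using the keyword arguments shown in this hunk. The leading `device` argument and the availability of an initialized distributed environment are assumptions, since the full class definition is not shown in the patch.

```python
import torch
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

def wrap_with_gemini(net: torch.nn.Module) -> GeminiDDP:
    # Assumption: the distributed environment is already initialized (e.g. via
    # colossalai.launch). The `device` argument is assumed from the class
    # definition, which is not fully shown in this patch.
    return GeminiDDP(net,
                     device=get_current_device(),
                     placement_policy='cpu',
                     pin_memory=True,
                     strict_ddp_mode=True,       # replicate every tensor, plain DDP style
                     search_range_mb=32)
```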
@ -47,30 +47,29 @@ def _get_shallow_copy_model(model: nn.Module):
|
|||
"""Get a shallow copy of the given model. Each submodule is different from the original submodule.
|
||||
But the new submodule and the old submodule share all attributes.
|
||||
"""
|
||||
name_to_module = dict()
|
||||
old_to_new = dict()
|
||||
for name, module in _get_dfs_module_list(model):
|
||||
new_module = copy(module)
|
||||
new_module._modules = OrderedDict()
|
||||
for subname, submodule in module._modules.items():
|
||||
if submodule is None:
|
||||
continue
|
||||
full_name = name + ('.' if name else '') + subname
|
||||
setattr(new_module, subname, name_to_module[full_name])
|
||||
name_to_module[name] = new_module
|
||||
return name_to_module['']
|
||||
setattr(new_module, subname, old_to_new[submodule])
|
||||
old_to_new[module] = new_module
|
||||
return old_to_new[model]
|
||||
|
||||
|
||||
def get_static_torch_model(gemini_ddp_model,
|
||||
def get_static_torch_model(zero_ddp_model,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
only_rank_0=True) -> torch.nn.Module:
|
||||
"""Get a static torch.nn.Module model from the given GeminiDDP module.
|
||||
You should notice that the original GeminiDDP model is not modified.
|
||||
"""Get a static torch.nn.Module model from the given ZeroDDP module.
|
||||
You should notice that the original ZeroDDP model is not modified.
|
||||
Thus, you can use the original model in further training.
|
||||
But you should not train with the returned torch model, as this can cause unexpected errors.
|
||||
|
||||
Args:
|
||||
gemini_ddp_model (GeminiDDP): a gemini ddp model
|
||||
zero_ddp_model (ZeroDDP): a zero ddp model
|
||||
device (torch.device): the device of the final torch model
|
||||
dtype (torch.dtype): the dtype of the final torch model
|
||||
only_rank_0 (bool): if True, only rank 0 has the converted torch model
|
||||
|
@ -78,11 +77,11 @@ def get_static_torch_model(gemini_ddp_model,
|
|||
Returns:
|
||||
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
|
||||
"""
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
assert isinstance(gemini_ddp_model, GeminiDDP)
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
assert isinstance(zero_ddp_model, ZeroDDP)
|
||||
|
||||
state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
|
||||
colo_model = gemini_ddp_model.module
|
||||
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
|
||||
colo_model = zero_ddp_model.module
|
||||
torch_model = _get_shallow_copy_model(colo_model)
|
||||
|
||||
if not only_rank_0 or dist.get_rank() == 0:
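A hedged usage sketch of `get_static_torch_model` for checkpoint export, following the docstring above; `zero_model` is assumed to be a trained `ZeroDDP`/`GeminiDDP` instance and the output path is illustrative.

```python
import torch
from colossalai.nn.parallel.utils import get_static_torch_model

def export_static_checkpoint(zero_model, path="static_model.pt"):
    # `zero_model` is assumed to be a trained ZeroDDP/GeminiDDP instance.
    torch_model = get_static_torch_model(zero_model,
                                         device=torch.device("cpu"),
                                         dtype=torch.float32,
                                         only_rank_0=True)
    # The returned module is meant for saving checkpoints or numeric checks,
    # not for further training (see the docstring above).
    torch.save(torch_model.state_dict(), path)
```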
@ -211,7 +211,7 @@ class WorkerBase(ABC):
|
|||
refcount = 0
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
if refcount < lifecycle:
|
||||
if refcount <= lifecycle:
|
||||
self.output_list[key] = output_work_item
|
||||
self.output_list_condition_lock.notify_all()
|
||||
|
||||
|
@ -390,7 +390,7 @@ class WorkerBase(ABC):
|
|||
subscribe_forward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank)
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
|
||||
else:
|
||||
for i in range(producer_num):
|
||||
|
|
|
@ -29,9 +29,6 @@ class FillDrainWorker(WorkerBase):
|
|||
|
||||
target_key = UniqueKey(target_microbatch_id, target_phase)
|
||||
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||
|
||||
return target_key
|
||||
|
||||
|
||||
|
|
|
@ -1,22 +1,46 @@
|
|||
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
|
||||
from .activation_checkpoint import checkpoint
|
||||
from .checkpointing import load_checkpoint, save_checkpoint
|
||||
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
|
||||
ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
|
||||
is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
|
||||
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
|
||||
sync_model_param, disposable)
|
||||
from .common import (
|
||||
clip_grad_norm_fp32,
|
||||
conditional_context,
|
||||
copy_tensor_parallel_attributes,
|
||||
count_zeros_fp32,
|
||||
disposable,
|
||||
ensure_path_exists,
|
||||
free_port,
|
||||
is_ddp_ignored,
|
||||
is_dp_rank_0,
|
||||
is_model_parallel_parameter,
|
||||
is_no_pp_or_last_stage,
|
||||
is_tp_rank_0,
|
||||
is_using_ddp,
|
||||
is_using_pp,
|
||||
is_using_sequence,
|
||||
multi_tensor_applier,
|
||||
param_is_not_tensor_parallel_duplicate,
|
||||
print_rank_0,
|
||||
switch_virtual_pipeline_parallel_rank,
|
||||
sync_model_param,
|
||||
)
|
||||
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction,
|
||||
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
|
||||
from .timer import MultiTimer, Timer
|
||||
from .memory import (
|
||||
colo_device_memory_capacity,
|
||||
colo_device_memory_used,
|
||||
colo_get_cpu_memory_capacity,
|
||||
colo_set_cpu_memory_capacity,
|
||||
colo_set_process_memory_fraction,
|
||||
report_memory_usage,
|
||||
)
|
||||
from .tensor_detector import TensorDetector
|
||||
from .timer import MultiTimer, Timer
|
||||
|
||||
__all__ = [
|
||||
'checkpoint',
|
||||
'free_port',
|
||||
'print_rank_0',
|
||||
'sync_model_param',
|
||||
'is_ddp_ignored',
|
||||
'is_dp_rank_0',
|
||||
'is_tp_rank_0',
|
||||
'is_no_pp_or_last_stage',
|
||||
|
|
|
@ -126,14 +126,18 @@ def is_model_parallel_parameter(p):
|
|||
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
|
||||
|
||||
|
||||
def is_ddp_ignored(p):
|
||||
return getattr(p, '_ddp_to_ignore', False)
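A small, self-contained example of the new helper: the `_ddp_to_ignore` attribute is the same flag read by `is_ddp_ignored` above, and setting it directly here is only for illustration.

```python
import torch
from colossalai.utils import is_ddp_ignored

p = torch.nn.Parameter(torch.zeros(4))
assert not is_ddp_ignored(p)          # the flag defaults to False

# Mark the parameter so the DDP/ZeRO machinery skips it
# (the attribute name mirrors the implementation above).
p._ddp_to_ignore = True
assert is_ddp_ignored(p)
```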
|
||||
|
||||
|
||||
def _calc_l2_norm(grads):
|
||||
# we should not
|
||||
# we should not
|
||||
global fused_optim
|
||||
|
||||
if fused_optim is None:
|
||||
from colossalai.kernel.op_builder import FusedOptimBuilder
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
|
||||
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
|
|
@ -0,0 +1,440 @@
|
|||
import contextlib
|
||||
import copy
|
||||
import gc
|
||||
import pprint
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||
_TorchFactoryMethod = [
|
||||
"arange",
|
||||
"empty",
|
||||
"eye",
|
||||
"full",
|
||||
"linspace",
|
||||
"logspace",
|
||||
"ones",
|
||||
"rand",
|
||||
"randn",
|
||||
"randint",
|
||||
"randperm",
|
||||
"zeros",
|
||||
"tensor",
|
||||
]
|
||||
|
||||
orig_empty = torch.empty # avoid override
|
||||
|
||||
scm = ShapeConsistencyManager()
|
||||
|
||||
|
||||
class LazyTensor(torch.Tensor):
|
||||
"""A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).
|
||||
|
||||
Usage:
|
||||
1. Use ``LazyTensor`` instead of ``torch.Tensor``.
|
||||
>>> x = LazyTensor(torch.zeros, 2, 3)
|
||||
>>> x += 1
|
||||
>>> y = x * x
|
||||
>>> y = y.cuda().half()
|
||||
>>> y[0, 0] = 0
|
||||
>>> y = y.materialize() # materialize the tensor
|
||||
>>> print(y)
|
||||
tensor([[0., 1., 1.],
|
||||
[1., 1., 1.]], device='cuda:0', dtype=torch.float16)
|
||||
|
||||
2. Generate ``MetaTensor`` from ``LazyTensor``
|
||||
>>> x = LazyTensor(torch.zeros, 2, 3)
|
||||
>>> x.reshape(3, 2)
|
||||
>>> x = x.traceable() # generate ``MetaTensor``
|
||||
>>> print(x)
|
||||
MetaTensor(..., size=(3, 2), device=cpu, dtype=torch.float32)
|
||||
|
||||
3. Use ``LazyTensor`` to generate sharded ``nn.Parameter``.
|
||||
>>> x = LazyTensor(torch.zeros, 2, 3)
|
||||
>>> x.spec = ... # some ``ShardingSpec``
|
||||
>>> x.distribute() # distribute the tensor according to the ``ShardingSpec``
|
||||
|
||||
Warnings:
|
||||
1. Cases that ``LazyTensor`` can't deal with.
|
||||
>>> x = LazyTensor(torch.ones, 2, 3)
|
||||
>>> x[0, 0] = -x[0, 0] # this will cause infinite recursion
|
||||
|
||||
2. ``LazyTensor.materialize()`` can't be called multiple times.
|
||||
>>> x = LazyTensor(torch.ones, 2, 3)
|
||||
>>> x.materialize()
|
||||
>>> x.materialize() # this is disallowed
|
||||
"""
|
||||
|
||||
_repr = True
|
||||
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
|
||||
_cached_data: Optional[torch.Tensor] = None # materialized data
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, func, *args, dtype=None, device=None, **kwargs):
|
||||
elem = func(*args, dtype=dtype, device='meta', **kwargs)
|
||||
r = torch.Tensor._make_wrapper_subclass(cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=device if device is not None else torch.device('cpu'),
|
||||
requires_grad=elem.requires_grad)
|
||||
r._meta_data = MetaTensor(elem, fake_device=device)
|
||||
return r
|
||||
|
||||
def __init__(self, func, *args, dtype=None, device=None, **kwargs):
|
||||
self._factory_method = (func, args, {'dtype': dtype, 'device': device, **kwargs}) # (func, args, kwargs)
|
||||
self._cached_buffer = list() # (func, args, kwargs)
|
||||
self._spec = None
|
||||
self._data = self
|
||||
|
||||
def __repr__(self):
|
||||
if self._repr:
|
||||
# avoid recursive representation
|
||||
self.__class__._repr = False
|
||||
s = f'LazyTensor(..., size={tuple(self._meta_data.shape)}, device={self._meta_data.device}, dtype={self._meta_data.dtype})\n'\
|
||||
f'factory method: {self._factory_method}\n'\
|
||||
f'cached: {pprint.pformat(self._cached_buffer) if self._cached_data is None else self._cached_data}\n'\
|
||||
f'spec: {self._spec}'
|
||||
self.__class__._repr = True
|
||||
return s
|
||||
else:
|
||||
return 'LazyTensor(...)'
|
||||
|
||||
def materialize(self) -> torch.Tensor:
|
||||
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
|
||||
|
||||
Warnings:
|
||||
Calling ``self.materialize()`` will clear the cached sequence and factory method,
because we don't allow materializing the same ``LazyTensor`` twice.
|
||||
This is mentioned in the paper: https://arxiv.org/pdf/2102.13267.pdf (Part 4.3).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The materialized tensor.
|
||||
"""
|
||||
target = self._data._realize_cached_data()
|
||||
if isinstance(self, nn.Parameter):
|
||||
target = nn.Parameter(target, requires_grad=self.requires_grad)
|
||||
self._clear_all()
|
||||
return target
|
||||
|
||||
def traceable(self) -> MetaTensor:
|
||||
"""Generate ``MetaTensor`` from ``LazyTensor``. (Mostly for tracing)
|
||||
|
||||
Returns:
|
||||
MetaTensor: The generated ``MetaTensor``.
|
||||
"""
|
||||
if isinstance(self, nn.Parameter):
|
||||
return nn.Parameter(self._meta_data, requires_grad=self.requires_grad)
|
||||
else:
|
||||
return self._meta_data
|
||||
|
||||
def distribute(self) -> torch.Tensor:
|
||||
"""Distribute the ``LazyTensor`` according to the ``ShardingSpec``.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The sharded tensor.
|
||||
"""
|
||||
if self._spec is None:
|
||||
raise RuntimeError(f'ShardingSpec is not set for\n{self}')
|
||||
spec, device_mesh = self._spec, self._spec.device_mesh
|
||||
target = self.materialize()
|
||||
|
||||
# TODO(some man): better not be coupled with auto-parallel
|
||||
target.data = scm.apply_for_autoparallel_runtime(target.data, ShardingSpec(device_mesh, target.shape, {}),
|
||||
spec).detach().clone()
|
||||
return target
|
||||
|
||||
def _realize_cached_data(self) -> torch.Tensor:
|
||||
# self._cached_data should be generated after the first call of this function
|
||||
if self._cached_data is None:
|
||||
if self._factory_method is not None:
|
||||
# apply factory method
|
||||
func, args, kwargs = self._factory_method
|
||||
|
||||
# apply cached sequence
|
||||
self._cached_data = self._apply_cache_buffer(func(*args, **kwargs))
|
||||
else:
|
||||
# apply cached sequence only
|
||||
self._cached_data = self._apply_cache_buffer()
|
||||
return self._cached_data
|
||||
|
||||
def _apply_cache_buffer(self, target=None) -> torch.Tensor:
|
||||
# dump all cached sequence
|
||||
# super-dainiu: support methods for single Tensor only
|
||||
def replace(x):
|
||||
if x is self:
|
||||
return target
|
||||
elif isinstance(x, LazyTensor):
|
||||
return x._realize_cached_data()
|
||||
return x
|
||||
|
||||
packed = None
|
||||
|
||||
for (func, args, kwargs) in self._cached_buffer:
|
||||
if func == torch.Tensor.requires_grad_:
|
||||
packed = func, args, kwargs # requires grad should be set at last
|
||||
else:
|
||||
o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value
|
||||
|
||||
# super-dainiu: set requires_grad after all inplace-ops are done
|
||||
if packed is not None:
|
||||
func, args, kwargs = packed
|
||||
func(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
|
||||
return target
|
||||
|
||||
# clear all means:
|
||||
# 1. clear factory method
|
||||
# 2. clear cached sequence
|
||||
# 3. clear cached data
|
||||
def _clear_all(self):
|
||||
self._cached_data = None
|
||||
self._cached_buffer = None
|
||||
self._data = None
|
||||
gc.collect() # avoid memory leak
|
||||
|
||||
# cache everything with __torch_function__
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
target = None
|
||||
|
||||
if isinstance(func, torch._C.ScriptMethod):
|
||||
|
||||
def unwrap(x):
|
||||
if isinstance(x, LazyTensor):
|
||||
return x._meta_data
|
||||
return x
|
||||
|
||||
target: LazyTensor = args[0].clone()
|
||||
target._cached_buffer.append((func, args, kwargs))
|
||||
target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
|
||||
**tree_map(unwrap, kwargs))
|
||||
|
||||
else:
|
||||
|
||||
def unwrap(x):
|
||||
nonlocal target
|
||||
if isinstance(x, LazyTensor):
|
||||
target = x if (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
|
||||
or func.__name__ == "__setitem__") else x.clone()
|
||||
target._cached_buffer.append((func, args, kwargs))
|
||||
return x._meta_data
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
o = func(*args, **kwargs)
|
||||
|
||||
if isinstance(o, MetaTensor):
|
||||
target._meta_data = o
|
||||
return target
|
||||
else:
|
||||
return o
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
pass # skip
|
||||
|
||||
def clone(self) -> "LazyTensor":
|
||||
"""Create a new ``LazyTensor`` with same cached sequence and factory method.
|
||||
|
||||
Returns:
|
||||
LazyTensor: the new ``LazyTensor``
|
||||
"""
|
||||
target = LazyTensor(orig_empty, 0, dtype=self._meta_data.dtype, device=self._meta_data.device)
|
||||
target._factory_method = None
|
||||
target._cached_buffer = list()
|
||||
target._meta_data = self._meta_data.clone()
|
||||
target._cached_data = self._cached_data.clone() if self._cached_data is not None else None
|
||||
target._spec = copy.deepcopy(self._spec)
|
||||
return target
|
||||
|
||||
def detach(self) -> "LazyTensor":
|
||||
target = self.clone()
|
||||
target._cached_buffer.append((torch.Tensor.detach_, (self,), {}))
|
||||
return target
|
||||
|
||||
@property
|
||||
def spec(self) -> ShardingSpec:
|
||||
return self._spec
|
||||
|
||||
@spec.setter
|
||||
def spec(self, other: ShardingSpec):
|
||||
self._spec = other
|
||||
|
||||
@property
|
||||
def data(self) -> "LazyTensor":
|
||||
return self._data.detach()
|
||||
|
||||
@data.setter
|
||||
def data(self, other: "LazyTensor") -> "LazyTensor":
|
||||
"""This avoid the following infinite recursion, which is very common in ``nn.Module`` initialization.
|
||||
|
||||
Usage:
|
||||
>>> a = LazyTensor(torch.empty, 0, dtype=torch.float32, device='cpu')
|
||||
>>> b = a.cuda()
|
||||
>>> a.data = b
|
||||
"""
|
||||
self._data = other
|
||||
|
||||
|
||||
class LazyInitContext():
|
||||
"""Context manager for lazy initialization. Enables initializing the model without allocating real memory.
|
||||
|
||||
Usage:
|
||||
1. The model is initialized, but no real memory is allocated.
|
||||
>>> ctx = LazyInitContext()
|
||||
>>> with ctx:
|
||||
>>> model = MyModel().cuda()
|
||||
|
||||
2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated.
|
||||
>>> with ctx.traceable(model):
|
||||
>>> gm = symbolic_trace(model, meta_args=meta_args)
|
||||
>>> # Solve the execution strategy and apply the strategy to the model
|
||||
>>> strategy = StrategyAndSpec()
|
||||
|
||||
3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device)
|
||||
>>> model = ctx.materialize(model)
|
||||
|
||||
4. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario)
|
||||
>>> model = apply_strategy_to_all_params(model, strategy)
|
||||
>>> model = ctx.distribute(model)
|
||||
|
||||
Warnings:
|
||||
This API is still experimental and further modifications can be made to it.
|
||||
For example:
|
||||
1. Quantization strategies can be applied before allocating real memory.
|
||||
2. Lazy initialization seems slower than normal initialization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.overrides = {}
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
def wrap_factory_method(target):
|
||||
# factory functions (eg. torch.empty())
|
||||
def wrapper(*args, **kwargs):
|
||||
return LazyTensor(target, *args, **kwargs)
|
||||
|
||||
return wrapper, target
|
||||
|
||||
def wrap_factory_like_method(orig_target, target):
|
||||
# factory_like functions (eg. torch.empty_like())
|
||||
def wrapper(*args, **kwargs):
|
||||
orig_t = args[0]
|
||||
return LazyTensor(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
|
||||
|
||||
return wrapper, target
|
||||
|
||||
self.overrides = {
|
||||
target: wrap_factory_method(getattr(torch, target))
|
||||
for target in _TorchFactoryMethod
|
||||
if callable(getattr(torch, target, None))
|
||||
}
|
||||
|
||||
self.overrides.update({
|
||||
target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
|
||||
for target in _TorchFactoryMethod
|
||||
if callable(getattr(torch, target + '_like', None))
|
||||
})
|
||||
|
||||
for name, (wrapper, orig) in self.overrides.items():
|
||||
setattr(torch, name, wrapper)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for name, (wrapper, orig) in self.overrides.items():
|
||||
setattr(torch, name, orig)
|
||||
|
||||
@staticmethod
|
||||
def materialize(module: torch.nn.Module):
|
||||
"""Initialize all ``nn.Parameter`` from ``LazyTensor``.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Target ``nn.Module``
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def init_recursively(module: nn.Module):
|
||||
# recursively initialize the module
|
||||
for mod in module.children():
|
||||
init_recursively(mod)
|
||||
|
||||
# initialize tensors directly attached to the current module
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, name, param.materialize())
|
||||
|
||||
for name, buf in module.named_buffers(recurse=False):
|
||||
setattr(module, name, buf.materialize())
|
||||
|
||||
init_recursively(module)
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
def distribute(module: torch.nn.Module):
|
||||
"""Initialize and shard all ``nn.Parameter`` from ``LazyTensor``.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Sharded target ``nn.Module``
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def init_recursively(module: nn.Module):
|
||||
# recursively initialize the module
|
||||
for mod in module.children():
|
||||
init_recursively(mod)
|
||||
|
||||
# initialize tensors directly attached to the current module
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, name, param.distribute())
|
||||
|
||||
for name, buf in module.named_buffers(recurse=False):
|
||||
setattr(module, name, buf.distribute())
|
||||
|
||||
init_recursively(module)
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def traceable(module: torch.nn.Module):
|
||||
"""Initialize all ``nn.Parameters`` as ``MetaTensor``. This enables ``ColoTracer`` with control flow.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Traceable ``nn.Module`` with ``MetaTensor`` as parameters.
|
||||
"""
|
||||
orig_val = dict()
|
||||
|
||||
def init_recursively(module: nn.Module):
|
||||
# recursively initialize the module
|
||||
for mod in module.children():
|
||||
init_recursively(mod)
|
||||
|
||||
# initialize tensors directly attached to the current module
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, name, param.traceable())
|
||||
orig_val[(module, name)] = param
|
||||
|
||||
for name, buf in module.named_buffers(recurse=False):
|
||||
setattr(module, name, buf.traceable())
|
||||
orig_val[(module, name)] = buf
|
||||
|
||||
init_recursively(module)
|
||||
|
||||
yield
|
||||
|
||||
# restore original values
|
||||
for (module, name), val in orig_val.items():
|
||||
setattr(module, name, val)
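A sketch mirroring the docstring's intended usage of `LazyInitContext`: build the model lazily, then materialize it on a single device. The import path is hypothetical, since this patch does not show where the new file lives.

```python
import torch
import torch.nn as nn

# Assumption: the LazyInitContext defined above is importable; its exact module
# path is not shown in this patch, so adjust the import accordingly.
from colossalai.utils.model.lazy_init_context import LazyInitContext   # hypothetical path

ctx = LazyInitContext()
with ctx:
    # factory functions are patched, so no real memory is allocated here and
    # parameters are created as LazyTensors
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# materialize: every LazyTensor parameter/buffer becomes a real torch.Tensor
model = LazyInitContext.materialize(model)
print(model(torch.randn(4, 8)).shape)
```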
|
|
@ -1,4 +1,5 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -7,6 +8,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
|||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.utils import is_model_parallel_parameter
|
||||
|
||||
|
||||
|
@ -101,7 +103,11 @@ def split_half_float_double(tensor_list):
|
|||
return buckets
|
||||
|
||||
|
||||
def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
|
||||
def reduce_tensor_dp_group(tensor: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
dst_local_rank: Optional[int] = None,
|
||||
dst_global_rank: Optional[int] = None,
|
||||
group: Optional[dist.ProcessGroup] = None):
|
||||
"""
|
||||
Reduce the tensor in the data parallel process group
|
||||
|
||||
|
@ -114,7 +120,7 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
|
|||
:type tensor: torch.Tensor
|
||||
:type dtype: torch.dtype, optional
|
||||
:type dst_rank: int, optional
|
||||
:type parallel_mode: ParallelMode, optional
|
||||
:type group: ProcessGroup, optional
|
||||
"""
|
||||
# use the original dtype
|
||||
if dtype is None:
|
||||
|
@ -126,25 +132,22 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
|
|||
else:
|
||||
tensor_to_reduce = tensor
|
||||
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
group = gpc.get_group(parallel_mode)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
tensor_to_reduce.div_(world_size)
|
||||
|
||||
# if rank is None, all reduce will be used
|
||||
# else, reduce is used
|
||||
use_all_reduce = dst_rank is None
|
||||
use_all_reduce = dst_local_rank is None
|
||||
|
||||
if use_all_reduce:
|
||||
dist.all_reduce(tensor_to_reduce, group=group)
|
||||
else:
|
||||
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
|
||||
global_rank = ranks_in_group[dst_rank]
|
||||
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
|
||||
dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)
|
||||
|
||||
# recover the original dtype
|
||||
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
|
||||
local_rank = gpc.get_local_rank(parallel_mode)
|
||||
if use_all_reduce or dst_rank == local_rank:
|
||||
local_rank = dist.get_rank(group=group)
|
||||
if use_all_reduce or dst_local_rank == local_rank:
|
||||
tensor.copy_(tensor_to_reduce)
|
||||
|
||||
return tensor
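A hedged wrapper around the new group-based reduction helper, using the signature introduced in this hunk. The import path is assumed from the `_utils` module being edited, and `torch.distributed` is assumed to be initialized on every rank of `group`.

```python
import torch
import torch.distributed as dist

# Import path assumed from the `_utils` module edited above.
from colossalai.zero.sharded_optim._utils import reduce_tensor_dp_group

def average_across_dp(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Assumption: torch.distributed is initialized and this runs on every rank
    # of `group`. Passing dst_local_rank=None selects the all-reduce path, and
    # the helper divides by the group's world size, i.e. it averages.
    return reduce_tensor_dp_group(tensor,
                                  dtype=torch.float32,
                                  dst_local_rank=None,
                                  dst_global_rank=None,
                                  group=group)
```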
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class BaseStore:
|
||||
|
||||
def __init__(self, dp_parallel_mode=ParallelMode.DATA):
|
||||
self._world_size = gpc.get_world_size(dp_parallel_mode)
|
||||
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
self._world_size = dist.get_world_size(group=torch_pg)
|
||||
self._local_rank = dist.get_rank(group=torch_pg)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .base_store import BaseStore
|
||||
|
||||
|
||||
class BucketStore(BaseStore):
|
||||
|
||||
def __init__(self, dp_parallel_mode):
|
||||
super().__init__(dp_parallel_mode)
|
||||
self._grads = dict()
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
self._params = dict()
|
||||
self._num_elements_in_bucket = dict()
|
||||
|
||||
|
@ -20,25 +18,24 @@ class BucketStore(BaseStore):
|
|||
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
|
||||
self._num_elements_in_bucket[reduce_rank] += num_elements
|
||||
|
||||
def add_grad(self, tensor, reduce_rank: int = None):
|
||||
self._grads[reduce_rank].append(tensor)
|
||||
|
||||
def add_param(self, tensor, reduce_rank: int = None):
|
||||
self._params[reduce_rank].append(tensor)
|
||||
|
||||
def reset(self):
|
||||
keys = [None] + list(range(self._world_size))
|
||||
self._grads = {rank: [] for rank in keys}
|
||||
self._params = {rank: [] for rank in keys}
|
||||
self._num_elements_in_bucket = {rank: 0 for rank in keys}
|
||||
|
||||
def reset_by_rank(self, reduce_rank=None):
|
||||
self._grads[reduce_rank] = []
|
||||
self._params[reduce_rank] = []
|
||||
self._num_elements_in_bucket[reduce_rank] = 0
|
||||
|
||||
def get_grad(self, reduce_rank: int = None):
|
||||
return self._grads[reduce_rank]
|
||||
param_list = self.get_param(reduce_rank)
|
||||
for param in param_list:
|
||||
# the param must have grad for reduction
|
||||
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
|
||||
return [param.grad for param in param_list]
|
||||
|
||||
def get_param(self, reduce_rank: int = None):
|
||||
return self._params[reduce_rank]
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .base_store import BaseStore
|
||||
|
||||
|
||||
class ParameterStore(BaseStore):
|
||||
|
||||
def __init__(self, dp_paralle_mode):
|
||||
super().__init__(dp_paralle_mode)
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
# param partitioning data structures
|
||||
self._fp16_param_to_rank = dict()
|
||||
self._rank_groupid_to_fp16_param_list = dict()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from functools import partial
|
||||
from itertools import groupby
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -10,6 +10,7 @@ from colossalai.context import ParallelMode
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from ._utils import (
|
||||
|
@ -18,7 +19,7 @@ from ._utils import (
|
|||
flatten,
|
||||
get_grad_accumulate_object,
|
||||
has_inf_or_nan,
|
||||
reduce_tensor,
|
||||
reduce_tensor_dp_group,
|
||||
release_param_grad,
|
||||
split_half_float_double,
|
||||
sync_param,
|
||||
|
@ -33,35 +34,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
|
||||
# grad scaler config
|
||||
initial_scale=2**16,
|
||||
min_scale=1,
|
||||
growth_factor=2,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=2000,
|
||||
hysteresis=2,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.,
|
||||
backoff_factor: float = .5,
|
||||
growth_interval: int = 2000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: int = 2**24,
|
||||
|
||||
# grad clipping
|
||||
clip_grad_norm=0.0,
|
||||
verbose=False,
|
||||
|
||||
# communication
|
||||
reduce_bucket_size=1024 * 1024,
|
||||
communication_dtype=None,
|
||||
overlap_communication=False,
|
||||
|
||||
# stage 2
|
||||
partition_grad=False,
|
||||
dp_parallel_mode=ParallelMode.DATA,
|
||||
mp_parallel_mode=ParallelMode.MODEL,
|
||||
|
||||
# cpu offload
|
||||
cpu_offload=False,
|
||||
|
||||
# forced dtype
|
||||
forced_dtype=None):
|
||||
clip_grad_norm: float = 0.0, # grad clipping
|
||||
verbose: bool = False,
|
||||
reduce_bucket_size: int = 1024 * 1024, # communication
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = False,
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
forced_dtype: Optional[torch.dtype] = None):
|
||||
|
||||
# TODO: add support for
|
||||
# 1. fp16 master weights
|
||||
|
@ -76,21 +63,32 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# stage 2
|
||||
self._partition_grads = partition_grad
|
||||
|
||||
# cpu_offload
|
||||
self._cpu_offload = cpu_offload
|
||||
|
||||
# get process groups
|
||||
self._dp_parallel_mode = dp_parallel_mode
|
||||
self._mp_parallel_mode = mp_parallel_mode
|
||||
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
|
||||
self._world_size = gpc.get_world_size(dp_parallel_mode)
|
||||
colo_pg = self._search_colo_process_group()
|
||||
if isinstance(colo_pg, ProcessGroup):
|
||||
self._local_rank = colo_pg.dp_local_rank()
|
||||
self._world_size = colo_pg.dp_world_size()
|
||||
self._dp_global_ranks = colo_pg.get_ranks_in_dp()
|
||||
self._dp_torch_group = colo_pg.dp_process_group()
|
||||
self._mp_torch_group = None
|
||||
if colo_pg.tp_world_size() > 1:
|
||||
self._mp_torch_group = colo_pg.tp_process_group()
|
||||
elif colo_pg is None:
|
||||
dp_parallel_mode = ParallelMode.DATA
|
||||
mp_parallel_mode = ParallelMode.MODEL
|
||||
|
||||
self._dp_group = gpc.get_group(dp_parallel_mode)
|
||||
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
|
||||
self._mp_group = gpc.get_group(mp_parallel_mode)
|
||||
self._dp_parallel_mode = dp_parallel_mode
|
||||
self._mp_parallel_mode = mp_parallel_mode
|
||||
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
|
||||
self._world_size = gpc.get_world_size(dp_parallel_mode)
|
||||
self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
|
||||
self._dp_torch_group = gpc.get_group(dp_parallel_mode)
|
||||
self._mp_torch_group = None
|
||||
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
|
||||
self._mp_torch_group = gpc.get_group(mp_parallel_mode)
|
||||
else:
|
||||
self._mp_group = None
|
||||
|
||||
raise NotImplementedError
|
||||
# fp16 and fp32 params for mixed precision training
|
||||
self._fp16_param_groups = dict()
|
||||
self._fp32_flat_param_groups_of_current_rank = dict()
|
||||
|
@ -126,9 +124,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
# ParameterStore will manage the tensor buffers used for zero
|
||||
# it will not manage the tensors used by mixed precision training
|
||||
self._param_store = ParameterStore(self._dp_parallel_mode)
|
||||
self._grad_store = GradientStore(self._dp_parallel_mode)
|
||||
self._bucket_store = BucketStore(self._dp_parallel_mode)
|
||||
self._param_store = ParameterStore(self._dp_torch_group)
|
||||
self._grad_store = GradientStore(self._dp_torch_group)
|
||||
self._bucket_store = BucketStore(self._dp_torch_group)
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
# partition these param groups for data parallel training
|
||||
|
@ -209,6 +207,30 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
def num_param_groups(self):
|
||||
return len(self._fp16_param_groups)
|
||||
|
||||
def _sanity_checks(self):
|
||||
assert torch.cuda.is_available(), 'CUDA is required'
|
||||
for param_group in self.optim.param_groups:
|
||||
group_params = param_group['params']
|
||||
for param in group_params:
|
||||
assert param.dtype == self._dtype, \
|
||||
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
|
||||
|
||||
def _search_colo_process_group(self):
|
||||
colo_flag = False
|
||||
colo_pg = None
|
||||
for param_group in self.optim.param_groups:
|
||||
group_params = param_group['params']
|
||||
for param in group_params:
|
||||
if isinstance(param, ColoParameter):
|
||||
colo_flag = True
|
||||
if colo_pg is None:
|
||||
colo_pg = param.get_process_group()
|
||||
else:
|
||||
assert colo_pg == param.get_process_group(), "All parameters should be in the same process group"
elif colo_flag:
raise RuntimeError("All parameters must be ColoParameter if any parameter is a ColoParameter.")
|
||||
return colo_pg
|
||||
|
||||
def _partition_param_list(self, param_list):
|
||||
params_per_rank = [[] for _ in range(self._world_size)]
|
||||
numel_per_rank = [0 for _ in range(self._world_size)]
|
||||
|
@ -223,22 +245,16 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
numel_per_rank[rank_to_go] += param.numel()
|
||||
|
||||
if self._verbose:
|
||||
self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
|
||||
ranks=[0],
|
||||
parallel_mode=self._dp_parallel_mode)
|
||||
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
|
||||
return params_per_rank
|
||||
|
||||
def _sanity_checks(self):
|
||||
assert torch.cuda.is_available(), 'CUDA is required'
|
||||
for param_group in self.optim.param_groups:
|
||||
group_params = param_group['params']
|
||||
for param in group_params:
|
||||
assert param.dtype == self._dtype, \
|
||||
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
|
||||
###########################
|
||||
# Backward Reduction Hook #
|
||||
###########################
|
||||
|
||||
###########################################################
|
||||
# Backward Reduction Hook
|
||||
###########################################################
|
||||
def _grad_handler(self, param, grad, reduce_rank):
|
||||
self._add_to_reduction_bucket(param, reduce_rank)
|
||||
return grad
|
||||
|
||||
def _attach_reduction_hook(self):
|
||||
# we iterate over the fp16 params
|
||||
|
@ -256,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
else:
|
||||
reduce_rank = None
|
||||
|
||||
def _define_and_attach(param, reduce_rank):
|
||||
# get the AccumulateGrad object of the param itself
|
||||
accum_grad_obj = get_grad_accumulate_object(param)
|
||||
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
|
||||
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
|
||||
|
||||
reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
|
||||
param=param,
|
||||
reduce_rank=reduce_rank)
|
||||
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
# define hook
|
||||
# NOT IMPORTANT BUT GOOD TO KNOW:
|
||||
# args here is not grad, but allow_unreacable and accumulate_grad
|
||||
def reduce_grad_hook(*args):
|
||||
reduction_func()
|
||||
with torch.cuda.stream(stream):
|
||||
flat = bucket.flatten()
|
||||
reduce_global_rank = None
|
||||
if reduce_rank is not None:
|
||||
reduce_global_rank = self._dp_global_ranks[reduce_rank]
|
||||
reduced_flat = reduce_tensor_dp_group(tensor=flat,
|
||||
dtype=self._communication_dtype,
|
||||
dst_local_rank=reduce_rank,
|
||||
dst_global_rank=reduce_global_rank,
|
||||
group=self._dp_torch_group)
|
||||
|
||||
accum_grad_obj.register_hook(reduce_grad_hook)
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
|
||||
_define_and_attach(param, reduce_rank)
|
||||
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
|
||||
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
|
||||
param_size = param.numel()
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._reduce_grads_in_bucket(reduce_rank)
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
|
||||
# the param must not be reduced to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
|
||||
+ 'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
# the param must have grad for reduction
|
||||
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
|
||||
def _reduce_grads(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
|
||||
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
||||
self._bucket_store.add_grad(param.grad, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
|
||||
bucket_size=bucket_size,
|
||||
reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_grads_in_bucket(self, reduce_rank=None):
|
||||
#######################
|
||||
# Reduction Functions #
|
||||
#######################
|
||||
|
||||
def _run_reduction(self, reduce_rank=None):
|
||||
# reduce grads
|
||||
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
|
||||
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
|
||||
self._reduce_grads(reduce_rank=reduce_rank,
|
||||
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
|
||||
|
||||
# use communication stream if overlapping
|
||||
# communication with computation
|
||||
|
@ -339,46 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
self._bucket_store.reset_by_rank(reduce_rank)
|
||||
|
||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
def _add_to_reduction_bucket(self, param, reduce_rank=None):
|
||||
param_size = param.numel()
|
||||
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._run_reduction(reduce_rank)
|
||||
|
||||
##############################
|
||||
# Reduction Utility Function #
|
||||
##############################
|
||||
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
# the param must not be reduced to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
|
||||
+ 'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
flat = bucket.flatten()
|
||||
reduced_flat = reduce_tensor(tensor=flat,
|
||||
dtype=self._communication_dtype,
|
||||
dst_rank=reduce_rank,
|
||||
parallel_mode=self._dp_parallel_mode)
|
||||
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
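The flush-then-add bucketing policy implemented by `_add_to_reduction_bucket` above can be illustrated with a small standalone sketch; the class, names, and sizes here are illustrative only and not part of the library.

```python
# Standalone illustration of the flush-then-add bucketing policy used by
# _add_to_reduction_bucket above. Names and sizes are illustrative only.
class ToyBucket:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.items, self.numel = [], 0

    def add(self, name: str, size: int, flush):
        if self.numel + size > self.capacity:
            flush(self.items)            # reduce everything already queued
            self.items, self.numel = [], 0
        self.items.append(name)
        self.numel += size

bucket = ToyBucket(capacity=1024)
bucket.add("w1", 600, flush=print)
bucket.add("w2", 600, flush=print)       # triggers a flush of ["w1"] first
```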
|
||||
|
||||
################################
|
||||
# torch.optim.Optimizer methods
|
||||
|
@ -443,8 +445,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
|
||||
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
|
||||
rank=self._local_rank),
|
||||
dp_group=self._dp_group,
|
||||
mp_group=self._mp_group)
|
||||
dp_group=self._dp_torch_group,
|
||||
mp_group=self._mp_torch_group)
|
||||
norm_groups.append(norm_group)
|
||||
|
||||
# create flat gradient for the flat fp32 params
|
||||
|
@ -482,9 +484,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# broadcast the updated model weights
|
||||
handles = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
for rank in range(self._world_size):
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
|
||||
for index in range(self._world_size):
|
||||
rank = self._dp_global_ranks[index]
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
|
||||
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
|
||||
handles.append(handle)
|
||||
|
||||
for handle in handles:
|
||||
|
@ -506,11 +509,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
break
|
||||
|
||||
# all-reduce across dp group
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
|
||||
|
||||
# all-reduce over model parallel group
|
||||
if self._mp_group:
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
|
||||
if self._mp_torch_group:
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
|
||||
|
||||
if self._found_overflow.item() > 0:
|
||||
return True
|
||||
|
@ -569,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
param_group = self._fp16_param_groups[group_id]
|
||||
for param in param_group:
|
||||
if param.grad is not None:
|
||||
self._reduce_and_remove_grads_by_bucket(param)
|
||||
self._add_to_reduction_bucket(param)
|
||||
|
||||
# we need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
self._reduce_grads_in_bucket()
|
||||
self._run_reduction()
|
||||
|
||||
def _reduce_grad_stage2(self):
|
||||
# when partition_grads is True, reduction hooks
|
||||
|
@ -581,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# only need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
for reduce_rank in range(self._world_size):
|
||||
self._reduce_grads_in_bucket(reduce_rank)
|
||||
self._run_reduction(reduce_rank)
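A construction sketch for the refactored optimizer using the new keyword arguments shown above. The import path, the fp16 parameter dtype, and an initialized global context (or a ColoParameter-based model) are all assumptions.

```python
import torch
from colossalai.zero import LowLevelZeroOptimizer    # import path assumed

def build_zero_optimizer(model: torch.nn.Module) -> LowLevelZeroOptimizer:
    # Assumption: colossalai.launch(...) has already set up the global context
    # (or the model uses ColoParameter so the process group is discovered), and
    # parameters are fp16 for mixed-precision training.
    base_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    return LowLevelZeroOptimizer(base_optim,
                                 initial_scale=2**16,
                                 clip_grad_norm=1.0,
                                 reduce_bucket_size=1024 * 1024,
                                 overlap_communication=True,
                                 partition_grad=False)   # True enables ZeRO stage 2
```

Note that the data-parallel and model-parallel process groups are now resolved internally, so no `dp_parallel_mode`/`mp_parallel_mode` arguments are passed any more.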
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch
|
|||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -24,7 +25,7 @@ class GeminiZeROHook(ColoParamOpHook):
|
|||
self._training_phase = TrainingPhase.FORWARD
|
||||
|
||||
def pre_op(self, params):
|
||||
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
|
||||
params = [p for p in params if not is_ddp_ignored(p)]
|
||||
chunks = self._chunk_manager.get_chunks(params)
|
||||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
|
@ -37,7 +38,7 @@ class GeminiZeROHook(ColoParamOpHook):
|
|||
self._gemini_manager.record_model_data_volume()
|
||||
|
||||
def post_op(self, params):
|
||||
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
|
||||
params = [p for p in params if not is_ddp_ignored(p)]
|
||||
for p in params:
|
||||
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
|
||||
self._chunk_manager.trans_tensor_state(p, tensor_state)
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
FROM hpcaitech/cuda-conda:11.3
|
||||
|
||||
# install torch
|
||||
RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
|
||||
RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
|
||||
|
||||
# install apex
|
||||
RUN git clone https://github.com/NVIDIA/apex && \
|
||||
cd apex && \
|
||||
pip install packaging && \
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./
|
||||
|
||||
# install colossalai
|
||||
RUN git clone https://github.com/hpcaitech/ColossalAI.git \
|
||||
&& cd ./ColossalAI \
|
||||
&& pip install -v --no-cache-dir .
|
||||
&& CUDA_EXT=1 pip install -v --no-cache-dir .
|
||||
|
||||
# install titans
|
||||
RUN pip install --no-cache-dir titans
|
||||
|
|
|
@@ -1,28 +1,40 @@
## Examples folder document
# Colossal-AI Examples

## Table of Contents
<ul>
 <li><a href="#Example-folder-description">Example folder description</a> </li>
 <li><a href="#Integrate-Your-Example-With-System-Testing">Integrate Your Example With System Testing</a> </li>
</ul>

## Example folder description
- [Colossal-AI Examples](#colossal-ai-examples)
  - [Table of Contents](#table-of-contents)
  - [Overview](#overview)
  - [Folder Structure](#folder-structure)
  - [Integrate Your Example With Testing](#integrate-your-example-with-testing)

This folder provides several examples using colossalai. The images folder includes model like diffusion, dreambooth and vit. The language folder includes gpt, opt, palm and roberta. The tutorial folder is for concept illustration, such as auto-parallel, hybrid-parallel and so on.
## Overview

This folder provides several examples accelerated by Colossal-AI. The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI. Other folders such as `images` and `language` include a wide range of deep learning tasks and applications.

## Integrate Your Example With System Testing
## Folder Structure

For example code contributor, to meet the expectation and test your code automatically using github workflow function, here are several steps:
```text
└─ examples
  └─ images
      └─ vit
        └─ test_ci.sh
        └─ train.py
        └─ README.md
      └─ ...
  └─ ...
```

## Integrate Your Example With Testing

- (must) Have a test_ci.sh file in the folder like shown below in 'File Structure Chart'
- The dataset should be located in the company's machine and can be announced using environment variable and thus no need for a separate terminal command.
- The model parameters should be small to allow fast testing.
- File Structure Chart
Regular checks are important to ensure that all examples run without apparent bugs and stay compatible with the latest API.
Colossal-AI runs workflows to check examples on an on-pull-request and weekly basis.
When a new example is added or changed, the workflow will run the example to verify that it runs.
Moreover, Colossal-AI will run testing for examples every week.

└─examples
  └─images
      └─vit
        └─requirements.txt
        └─test_ci.sh
Therefore, it is essential for example contributors to know how to integrate their examples with the testing workflow. You can simply follow the steps below.

1. Create a script called `test_ci.sh` in your example folder.
2. Configure your testing parameters, such as the number of steps and the batch size, in `test_ci.sh`. Keep these parameters small so that each example only takes several minutes.
3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via Slack to request that the dataset be downloaded onto the CI machine.
4. Implement the logic, such as dependency setup and example execution.

@@ -26,6 +26,16 @@ Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1]

More details can be found in our [blog of Stable Diffusion v1](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) and [blog of Stable Diffusion v2](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0).


## Roadmap
This project is in rapid development.

- [X] Train a Stable Diffusion v1/v2 model from scratch
- [X] Finetune a pretrained Stable Diffusion v1 model
- [X] Run inference on a pretrained model with PyTorch
- [ ] Finetune a pretrained Stable Diffusion v2 model
- [ ] Run inference on a pretrained model with TensorRT

## Installation

### Option #1: install from source

@@ -123,7 +133,7 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4

### stable-diffusion-v1-5 from runway

If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) wiegh from runwayml
If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from runwayml

```
git lfs install

@@ -156,7 +166,7 @@ You can change the training config in the yaml file
- precision: the precision type used in training, default 16 (fp16); you must use fp16 if you want to apply colossalai
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)

## Finetune Example
## Finetune Example (Work In Progress)
### Training on Teyvat Datasets

We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is created with BLIP-generated captions.

@@ -153,7 +153,8 @@ def parse_args(input_args=None):
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
        help=
        "Number of update steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1.",
    )
    parser.add_argument(
        "--gradient_checkpointing",

@@ -355,10 +356,14 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):


def main(args):
    colossalai.launch_from_torch(config={})

    if args.seed is not None:
        gpc.set_seed(args.seed)
    if args.seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=args.seed)

    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)

@@ -387,7 +392,7 @@ def main(args):
            for example in tqdm(
                    sample_dataloader,
                    desc="Generating class images",
                    disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
                    disable=not local_rank == 0,
            ):
                images = pipeline(example["prompt"]).images

@@ -399,7 +404,7 @@ def main(args):
        del pipeline

    # Handle the repository creation
    if gpc.get_local_rank(ParallelMode.DATA) == 0:
    if local_rank == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)

@@ -464,8 +469,9 @@ def main(args):
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size

    unet = gemini_zero_dpp(unet, args.placement)

@@ -554,7 +560,7 @@ def main(args):
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
    total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps

    logger.info("***** Running training *****", ranks=[0])
    logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])

@@ -566,7 +572,7 @@ def main(args):
    logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
    progress_bar.set_description("Steps")
    global_step = 0

@@ -643,7 +649,7 @@ def main(args):
                if global_step % args.save_steps == 0:
                    torch.cuda.synchronize()
                    torch_unet = get_static_torch_model(unet)
                    if gpc.get_local_rank(ParallelMode.DATA) == 0:
                    if local_rank == 0:
                        pipeline = DiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=torch_unet,

@@ -658,7 +664,7 @@ def main(args):
    torch.cuda.synchronize()
    unet = get_static_torch_model(unet)

    if gpc.get_local_rank(ParallelMode.DATA) == 0:
    if local_rank == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=unet,

@@ -0,0 +1,32 @@
from colossalai.amp import AMP_TYPE

# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 8
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 3
WARMUP_EPOCHS = 1

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 32
DEPTH = 2
NUM_HEADS = 4
MLP_RATIO = 4
NUM_CLASSES = 10
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token

USE_DDP = True
TP_WORLD_SIZE = 2
TP_TYPE = 'row'
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
gradient_accumulation = 2

LOG_PATH = "./log_ci"

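As the comments in this config note, `BATCH_SIZE` is per GPU, and the effective global batch also scales with `gradient_accumulation`. A quick sanity check of the batch size implied by the config above, assuming a 4-GPU launch so that the data-parallel size is `world_size / TP_WORLD_SIZE = 2` (the GPU count is an assumption for illustration only):

```python
# Hypothetical launch on 4 GPUs with the CI config above.
world_size = 4                     # assumed number of GPUs
tp_size = 2                        # TP_WORLD_SIZE in the config
dp_size = world_size // tp_size    # data-parallel size = 2
batch_per_gpu = 8                  # BATCH_SIZE in the config
grad_accum = 2                     # gradient_accumulation in the config

global_batch = batch_per_gpu * dp_size * grad_accum
print(global_batch)                # 32 samples per optimizer step
```
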
@@ -1,2 +1,8 @@
colossalai >= 0.1.12
torch >= 1.8.1
numpy>=1.24.1
timm>=0.6.12
titans>=0.0.7
tqdm>=4.61.2
transformers>=4.25.1
nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist


@@ -0,0 +1,9 @@
export OMP_NUM_THREADS=4

pip install -r requirements.txt

# train
colossalai run \
    --nproc_per_node 4 train.py \
    --config configs/vit_1d_tp2_ci.py \
    --dummy_data

@@ -7,6 +7,7 @@ import torch.nn.functional as F
from timm.models.vision_transformer import _create_vision_transformer
from titans.dataloader.imagenet import build_dali_imagenet
from tqdm import tqdm
from vit import DummyDataLoader

import colossalai
from colossalai.core import global_context as gpc

@@ -56,8 +57,8 @@ def init_spec_func(model, tp_type):

def train_imagenet():

    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=True, action='store_true')
    parser.add_argument('--resume_from', default=False)
    parser.add_argument('--resume_from', default=False, action='store_true')
    parser.add_argument('--dummy_data', default=False, action='store_true')

    args = parser.parse_args()
    colossalai.launch_from_torch(config=args.config)

@@ -74,10 +75,22 @@ def train_imagenet():
        logger.log_to_file(log_path)

    logger.info('Build data loader', ranks=[0])
    root = os.environ['DATA']
    train_dataloader, test_dataloader = build_dali_imagenet(root,
                                                            train_batch_size=gpc.config.BATCH_SIZE,
                                                            test_batch_size=gpc.config.BATCH_SIZE)
    if not args.dummy_data:
        root = os.environ['DATA']
        train_dataloader, test_dataloader = build_dali_imagenet(root,
                                                                train_batch_size=gpc.config.BATCH_SIZE,
                                                                test_batch_size=gpc.config.BATCH_SIZE)
    else:
        train_dataloader = DummyDataLoader(length=10,
                                           batch_size=gpc.config.BATCH_SIZE,
                                           category=gpc.config.NUM_CLASSES,
                                           image_size=gpc.config.IMG_SIZE,
                                           return_dict=False)
        test_dataloader = DummyDataLoader(length=5,
                                          batch_size=gpc.config.BATCH_SIZE,
                                          category=gpc.config.NUM_CLASSES,
                                          image_size=gpc.config.IMG_SIZE,
                                          return_dict=False)

    logger.info('Build model', ranks=[0])

@@ -32,21 +32,24 @@ class DummyDataGenerator(ABC):


class DummyDataLoader(DummyDataGenerator):
    batch_size = 4
    channel = 3
    category = 8
    image_size = 224

    def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
        super().__init__(length)
        self.batch_size = batch_size
        self.channel = channel
        self.category = category
        self.image_size = image_size
        self.return_dict = return_dict

    def generate(self):
        image_dict = {}
        image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
                                                DummyDataLoader.channel,
                                                DummyDataLoader.image_size,
                                                DummyDataLoader.image_size,
                                                device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
        image_dict['pixel_values'] = torch.rand(
            self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
                                            dtype=torch.int64,
                                            device=get_current_device())
        if not self.return_dict:
            return image_dict['pixel_values'], image_dict['label']
        return image_dict

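The change above makes the dummy loader honor its constructor arguments instead of the class-level defaults. A brief usage sketch, under the assumption that `DummyDataGenerator` yields `length` freshly generated batches per epoch when iterated (shapes follow the ViT CI config earlier in this diff):

```python
# Hypothetical usage of the refactored DummyDataLoader.
loader = DummyDataLoader(length=10,
                         batch_size=8,
                         category=10,
                         image_size=224,
                         return_dict=False)

for pixel_values, labels in loader:
    # pixel_values: (8, 3, 224, 224) floats in [-1, 1); labels: (8,) int64 class ids
    assert pixel_values.shape == (8, 3, 224, 224)
    assert labels.shape == (8,)
```
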
@@ -39,9 +39,15 @@ If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-
For simplicity, the input data is randomly generated here.

## Training
We provide two solutions. One utilizes the hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism.
The other one uses Pipeline Parallelism Only.
In the future, we are going merge them together and they can be used orthogonally to each other.
We provide two stable solutions.
One uses Gemini to implement a hybrid parallel strategy of Gemini, DDP/ZeRO, and Tensor Parallelism for a Hugging Face GPT model.
The other uses [Titans](https://github.com/hpcaitech/Titans), a model zoo of distributed implementations maintained by ColossalAI, to implement the hybrid parallel strategy of TP + ZeRO + PP.

We recommend using Gemini to quickly run your model in a distributed manner.
It doesn't require significant changes to the model structure, so you can easily apply it to a new model.
Use Titans as an advanced option when you want to pursue more extreme performance.
Titans already includes some typical models, such as ViT and GPT.
However, it requires some effort to get started with a new model structure.

### GeminiDPP/ZeRO + Tensor Parallelism
```bash

@@ -56,6 +62,11 @@ The `train_gpt_demo.py` provides three distributed plans, you can choose the pla
- Pytorch DDP
- Pytorch ZeRO

### Titans (Tensor Parallelism) + ZeRO + Pipeline Parallelism

Titans provides a customized GPT model, which uses distributed operators as building blocks.
In [./titans/README.md](./titans/README.md), we provide a hybrid parallelism of ZeRO, TP and PP.
You can switch parallel strategies using a config file, as sketched below.
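A minimal sketch of what such a config file could look like, following the `parallel = dict(...)` convention used by the ViT CI config earlier in this diff; the exact keys accepted by the Titans GPT example may differ, so the pipeline/tensor sizes below are illustrative assumptions rather than values taken from this PR:

```python
# hypothetical hybrid-parallel config sketch (not a file from this PR)
from colossalai.amp import AMP_TYPE

BATCH_SIZE = 8
NUM_EPOCHS = 10

# 2-stage pipeline parallelism combined with 1D tensor parallelism of size 2
parallel = dict(
    pipeline=2,
    tensor=dict(mode='1d', size=2),
)

# ZeRO-style sharding is typically toggled in the training script itself;
# fp16 here only enables naive mixed precision, as in the ViT config above.
fp16 = dict(mode=AMP_TYPE.NAIVE)
```
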

## Performance


@@ -16,14 +16,14 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger

BATCH_SIZE = 8
SEQ_LENGTH = 128
HIDDEN_DIM = 3072
BATCH_SIZE = 16
SEQ_LENGTH = 1024
HIDDEN_DIM = 4096
NUM_HEADS = 16
NUM_LAYERS = 1
NUM_LAYERS = 4
VOCAB_SIZE = 50257
NUM_STEPS = 10
FP16 = False
FP16 = True


def get_cpu_mem():

@@ -40,7 +40,7 @@ def get_mem_info(prefix=''):

def get_tflops(model_numel, batch_size, seq_len, step_time):
    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8
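The per-GPU TFLOPS estimator above uses the common `8 * numel * tokens` approximation for forward-plus-backward FLOPs, and after this change divides by 8 GPUs instead of 4. A quick numeric check restating the same formula, with a hypothetical 1.3B-parameter model and the constants above (the parameter count and the 1 s step time are assumptions for illustration):

```python
def get_tflops(model_numel, batch_size, seq_len, step_time, n_gpus=8):
    # same formula as in the diff, with the GPU count made an explicit argument
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / n_gpus


# assumed: 1.3e9 parameters, BATCH_SIZE=16, SEQ_LENGTH=1024, 1.0 s per step
print(round(get_tflops(1.3e9, 16, 1024, 1.0), 1))    # ~21.3 TFLOPS per GPU
```
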
# Randomly Generated Data

@@ -66,13 +66,7 @@ def main():
        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
    }

    # Both device mesh initialization and model initialization will be integrated into autoparallelize
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # Enable auto-parallel
    gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
    gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)

    # print solution on rank 0
    if gpc.get_global_rank() == 0:

Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1

@@ -120,7 +120,7 @@ def run_master(args):
    logger.info(f'{rank=} numel in the partition:{numel}')

    # build optim
    pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
    pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    ranks_tflops = {}
    for n in range(NUM_STEPS):

@@ -1,18 +1,20 @@
for MODEL_TYPE in "gpt2_medium"; do
for BATCH_SIZE in 16; do
for GPUNUM in 1 2 4 8; do
for TPDEGREE in 1 2 4 8; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
echo "****************** Begin ***************************"
echo "* benchmrking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} BS ${BS} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}"
MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./gemini/run_gemini.sh
echo "****************** Finished ***************************"
echo ""
echo ""
for DISTPLAN in "colossalai"; do
for BATCH_SIZE in 16; do
for GPUNUM in 1 2 4 8; do
for TPDEGREE in 1 2 4 8; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
echo "****************** Begin ***************************"
echo "+ benchmarking MODEL ${MODEL_TYPE} DISTPLAN ${DISTPLAN} GPU ${GPUNUM} BS ${BATCH_SIZE} TP ${TPDEGREE} POLICY ${PLACEMENT}"
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
echo "****************** Finished ***************************"
echo ""
echo ""
done
done
done
done

@@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_30b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_40b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def model_builder(model_size: str) -> callable:
    if model_size == "gpt2_medium":
        return gpt2_medium

@@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable:
        return gpt2_20b
    elif model_size == "gpt2_24b":
        return gpt2_24b
    elif model_size == "gpt2_30b":
        return gpt2_30b
    elif model_size == "gpt2_40b":
        return gpt2_40b
    else:
        raise TypeError(f"model_builder {model_size}")

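The new `gpt2_30b` and `gpt2_40b` presets can be sanity-checked from the layer counts alone, using the usual rule of thumb that each GPT block holds roughly `12 * hidden_size**2` parameters (the estimate below ignores embeddings and is only an approximation, not a value taken from this PR):

```python
def approx_gpt_params(hidden_size: int, num_layers: int) -> float:
    # ~12 * h^2 parameters per transformer block (attention + MLP), embeddings ignored
    return 12 * num_layers * hidden_size ** 2


print(f"{approx_gpt_params(8192, 37) / 1e9:.1f}B")    # ~29.8B -> gpt2_30b
print(f"{approx_gpt_params(8192, 50) / 1e9:.1f}B")    # ~40.3B -> gpt2_40b
```
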
@@ -0,0 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1

@@ -1,15 +1,15 @@
set -x
# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
export DISTPAN=${DISTPAN:-"colossalai"}
export DISTPLAN=${DISTPLAN:-"colossalai"}

# The following options only valid when DISTPAN="colossalai"
# The following options only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}

export TRAIN_STEP=${TRAIN_STEP:-10}
# export PYTHONPATH=$PWD:$PYTHONPATH

mkdir -p gemini_logs

@@ -20,5 +20,6 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
          --batch_size=${BATCH_SIZE} \
          --placement=${PLACEMENT} \
          --shardinit=${USE_SHARD_INIT} \
          --distplan=${DISTPAN} \
          2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
          --distplan=${DISTPLAN} \
          --train_step=${TRAIN_STEP} \
          2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log

@@ -0,0 +1,35 @@
set -x
$(cd `dirname $0`;pwd)
export TRAIN_STEP=4

for MODEL_TYPE in "gpt2_medium"; do
for DISTPLAN in "colossalai"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1 2; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
done
done
done
done
done

for DISTPLAN in "zero1" "zero2"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
bash ./run_gemini.sh
done
done
done
done
done

@@ -65,6 +65,13 @@ def parse_args():
        default="gpt2_medium",
        help="model model scale",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="training iterations for test",
    )

    args = parser.parse_args()
    return args


@@ -180,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):


# Gemini + ZeRO DDP
def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
    fp16_init_scale = 2**5
    gpu_margin_mem_ratio_for_auto = 0

    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
        model = GeminiDDP(model,
                          strict_ddp_mode=ddp_flag,
                          device=get_current_device(),
                          placement_policy=placement_policy,
                          pin_memory=True,
                          hidden_dim=model.config.n_embd,
                          search_range_mb=64)
                          search_range_mb=128)
        # configure the const policy
        if placement_policy == 'const':
            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)

@@ -236,7 +244,8 @@ def main():
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    NUM_STEPS = 10
    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "

@@ -270,14 +279,17 @@ def main():

        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
        # Tensor Parallelism (TP)
        tensor_parallelize(model, tp_pg)
        # You should notice that v0.1.10 is not compatible with TP degree > 1
        if args.tp_degree > 1:
            tensor_parallelize(model, tp_pg)

        # build a Gemini model and a highly optimized cpu optimizer
        # Gemini + ZeRO DP, Note it must be used after TP
        model, optimizer = build_gemini(model, tp_pg, args.placement)
        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    else:
        assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(args.model_type)(checkpoint=True).cuda()

    if args.distplan.startswith("torch"):

@@ -288,12 +300,17 @@ def main():
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
    elif args.distplan.startswith("zero"):
        partition_flag = args.distplan == "zero2"
        model = model.half()
        partition_flag = (args.distplan == "zero2")
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        optimizer = LowLevelZeroOptimizer(optimizer,
                                          overlap_communication=True,
                                          partition_grad=partition_flag,
                                          verbose=True)

        optimizer = LowLevelZeroOptimizer(
            optimizer,
            reduce_bucket_size=12 * 1024 * 1024,
            overlap_communication=True,
            partition_grad=partition_flag,
            verbose=True,
        )

    # model is shared after TP
    numel = get_model_size(model)

Some files were not shown because too many files have changed in this diff.