mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk
commit e532679c95
@@ -20,6 +20,8 @@ body:
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Optional: Affiliation**
Institution/email information helps us better analyze and evaluate users so that we can improve the project. You are welcome to establish in-depth cooperation with us.
placeholder: |
A clear and concise description of what the bug is.
validations:
@@ -17,6 +17,7 @@ body:
**Expectation** What did you expect the content to be?
**Screenshots** If applicable, add screenshots to help explain your problem.
**Suggestions** Tell us how we could improve the documentation.
**Optional: Affiliation** Institution/email information helps us better analyze and evaluate users so that we can improve the project. You are welcome to establish in-depth cooperation with us.
placeholder: |
A clear and concise description of the issue.
validations:
@@ -22,6 +22,8 @@ body:
If applicable, add screenshots to help explain your problem.
**Suggest a potential alternative/fix**
Tell us how we could improve this project.
**Optional: Affiliation**
Institution/email information helps us better analyze and evaluate users so that we can improve the project. You are welcome to establish in-depth cooperation with us.
placeholder: |
A clear and concise description of your idea.
validations:
@@ -13,6 +13,7 @@ body:
- Bumping a critical dependency's major version;
- A significant improvement in user-friendliness;
- Significant refactor;
- Optional: Affiliation/email information helps us better analyze and evaluate users so that we can improve the project. You are welcome to establish in-depth cooperation with us.
- ...

Please note that this template is not for feature requests or bug reports; using it for those could cause us to misidentify the issue and close it without taking any action.
@@ -43,4 +44,4 @@ body:
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
Thanks for contributing 🎉!
@@ -1,9 +0,0 @@
addReviewers: true

addAssignees: author

numberOfReviewers: 1

reviewers:
- frankleeeee
- kurisusnowdeng
@ -1,18 +0,0 @@
|
|||
name: Assign Reviewers for Team
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
assign_reviewer:
|
||||
name: Assign Reviewer for PR
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.pull_request.draft == false && github.base_ref == 'main'
|
||||
&& github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
&& toJson(github.event.pull_request.requested_reviewers) == '[]'
|
||||
steps:
|
||||
- uses: kentaro-m/auto-assign-action@v1.2.1
|
||||
with:
|
||||
configuration-path: '.github/reviewer_list.yml'
|
|
@ -0,0 +1,130 @@
|
|||
name: Test Example
|
||||
on:
|
||||
pull_request:
|
||||
# any change in the examples folder will trigger check for the corresponding example.
|
||||
paths:
|
||||
- 'examples/**'
|
||||
# run at 00:00 every Sunday (Singapore time), which is 16:00 UTC on Saturday
|
||||
schedule:
|
||||
- cron: '0 16 * * 6'
|
||||
|
||||
jobs:
|
||||
# This job detects changed example files and outputs a matrix containing all the corresponding directory names.
|
||||
detect-changed-example:
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.setup-matrix.outputs.matrix }}
|
||||
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
|
||||
name: Detect changed example files
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Get all changed example files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v35
|
||||
# Compare against the last pushed commit so that the check runs each time the PR is updated.
|
||||
with:
|
||||
since_last_remote_commit: true
|
||||
- name: setup matrix
|
||||
id: setup-matrix
|
||||
run: |
|
||||
changedFileName=""
|
||||
for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
|
||||
changedFileName="${file}:${changedFileName}"
|
||||
done
|
||||
echo "$changedFileName was changed"
|
||||
res=`python .github/workflows/scripts/example_checks/detect_changed_example.py --fileNameList $changedFileName`
|
||||
echo "All changed examples are $res"
|
||||
|
||||
if [ "$x" = "[]" ]; then
|
||||
echo "anyChanged=false" >> $GITHUB_OUTPUT
|
||||
echo "matrix=null" >> $GITHUB_OUTPUT
|
||||
else
|
||||
dirs=$( IFS=',' ; echo "${res[*]}" )
|
||||
echo "anyChanged=true" >> $GITHUB_OUTPUT
|
||||
echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# If no example file was changed, this job reports an error because the matrix has no value.
|
||||
check-changed-example:
|
||||
# This condition prevents the job from running when the trigger event is not a pull request (e.g. the scheduled run).
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
|
||||
name: Test the changed example
|
||||
needs: detect-changed-example
|
||||
runs-on: [self-hosted, gpu]
|
||||
strategy:
|
||||
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
options: --gpus all --rm -v /data/scratch/examples-data:/data/
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
pip install -v .
|
||||
- name: Test the example
|
||||
run: |
|
||||
example_dir=${{ matrix.directory }}
|
||||
cd "${PWD}/examples/${example_dir}"
|
||||
bash test_ci.sh
|
||||
env:
|
||||
NCCL_SHM_DISABLE: 1
|
||||
|
||||
# This is for the weekly check of all examples. Specifically, this job finds all the example directories.
|
||||
matrix_preparation:
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
|
||||
name: Prepare matrix for weekly check
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.setup-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- name: 📚 Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: setup matrix
|
||||
id: setup-matrix
|
||||
run: |
|
||||
res=`python .github/workflows/scripts/example_checks/check_example_weekly.py`
|
||||
all_loc=$( IFS=',' ; echo "${res[*]}" )
|
||||
echo "Found the examples: $all_loc"
|
||||
echo "matrix={\"directory\":$(echo "$all_loc")}" >> $GITHUB_OUTPUT
|
||||
|
||||
weekly_check:
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
|
||||
name: Weekly check all examples
|
||||
needs: matrix_preparation
|
||||
runs-on: [self-hosted, gpu]
|
||||
strategy:
|
||||
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: 📚 Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
pip install -v .
|
||||
- name: Traverse all files
|
||||
run: |
|
||||
example_dir=${{ matrix.directory }}
|
||||
echo "Testing ${example_dir} now"
|
||||
cd "${PWD}/examples/${example_dir}"
|
||||
bash test_ci.sh
|
||||
env:
|
||||
NCCL_SHM_DISABLE: 1
|
|
@ -1,17 +1,45 @@
|
|||
name: Build
|
||||
|
||||
on:
|
||||
on:
|
||||
pull_request:
|
||||
types: [synchronize, labeled]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build and Test Colossal-AI
|
||||
detect:
|
||||
name: Detect kernel-related file change
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
|
||||
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
|
||||
outputs:
|
||||
changedFiles: ${{ steps.find-changed-files.outputs.changedFiles }}
|
||||
anyChanged: ${{ steps.find-changed-files.outputs.any_changed }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Find the changed files
|
||||
id: find-changed-files
|
||||
uses: tj-actions/changed-files@v35
|
||||
with:
|
||||
since_last_remote_commit: true
|
||||
files: |
|
||||
op_builder/**
|
||||
colossalai/kernel/**
|
||||
setup.py
|
||||
- name: List changed files
|
||||
run: |
|
||||
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
|
||||
echo "$file was changed"
|
||||
done
|
||||
|
||||
|
||||
build:
|
||||
name: Build and Test Colossal-AI
|
||||
needs: detect
|
||||
runs-on: [self-hosted, gpu]
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.11.0-11.3.0
|
||||
|
@ -23,27 +51,38 @@ jobs:
|
|||
repository: hpcaitech/TensorNVMe
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
path: TensorNVMe
|
||||
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
cd TensorNVMe
|
||||
conda install cmake
|
||||
pip install -r requirements.txt
|
||||
pip install -v .
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
- name: Install Colossal-AI
|
||||
|
||||
- name: Restore cache
|
||||
if: needs.detect.outputs.anyChanged != 'true'
|
||||
run: |
|
||||
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
|
||||
pip install -r requirements/requirements.txt
|
||||
pip install -v -e .
|
||||
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
||||
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
|
||||
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
|
||||
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
|
||||
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
CUDA_EXT=1 pip install -v -e .
|
||||
pip install -r requirements/requirements-test.txt
|
||||
|
||||
- name: Unit Testing
|
||||
run: |
|
||||
PYTHONPATH=$PWD pytest tests
|
||||
PYTHONPATH=$PWD pytest --cov=. --cov-report lcov tests
|
||||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
NCCL_SHM_DISABLE: 1
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
|
||||
- name: Store Cache
|
||||
run: |
|
||||
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
|
||||
cp -p -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
||||
|
|
|
@ -2,7 +2,7 @@ name: Build on 8 GPUs
|
|||
|
||||
on:
|
||||
schedule:
|
||||
# run at 00:00 of every Sunday
|
||||
# run at 00:00 of every Sunday
|
||||
- cron: '0 0 * * *'
|
||||
workflow_dispatch:
|
||||
|
||||
|
@ -30,13 +30,11 @@ jobs:
|
|||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
- name: Install Colossal-AI
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
|
||||
pip install -r requirements/requirements.txt
|
||||
pip install -v -e .
|
||||
CUDA_EXT=1 pip install -v -e .
|
||||
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
||||
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
|
||||
pip install -r requirements/requirements-test.txt
|
||||
- name: Unit Testing
|
||||
run: |
|
||||
|
@ -45,4 +43,3 @@ jobs:
|
|||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
|
|
@ -70,7 +70,7 @@ jobs:
|
|||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
- name: Install Colossal-AI
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
pip install -r requirements/requirements.txt
|
||||
pip install -v --no-cache-dir .
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
name: Manual Test Example
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
example_directory:
|
||||
type: string
|
||||
description: Example directories, separated by commas, e.g. language/gpt,images/vit. Entering only an area such as language, or only an application name such as gpt, does not work.
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
matrix_preparation:
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
name: Check the examples user want
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- name: 📚 Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up matrix
|
||||
id: set-matrix
|
||||
env:
|
||||
check_dir: ${{ inputs.example_directory }}
|
||||
run: |
|
||||
res=`python .github/workflows/scripts/example_checks/check_dispatch_inputs.py --fileNameList $check_dir`
|
||||
if [ "$res" = "failure" ]; then
|
||||
exit -1
|
||||
fi
|
||||
dirs="[${check_dir}]"
|
||||
echo "Testing examples in $dirs"
|
||||
echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
|
||||
|
||||
test_example:
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
name: Manually check example files
|
||||
needs: matrix_preparation
|
||||
runs-on: [self-hosted, gpu]
|
||||
strategy:
|
||||
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
options: --gpus all --rm -v /data/scratch/examples-data:/data/
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: 📚 Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
pip install -v .
|
||||
- name: Test the example
|
||||
run: |
|
||||
dir=${{ matrix.directory }}
|
||||
echo "Testing ${dir} now"
|
||||
cd "${PWD}/examples/${dir}"
|
||||
bash test_ci.sh
|
||||
env:
|
||||
NCCL_SHM_DISABLE: 1
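
For reference, this dispatch-only check can also be kicked off from the command line; a hedged sketch using the GitHub CLI, assuming the workflow file is saved as `.github/workflows/example_check_on_dispatch.yml` (the filename is not shown in this diff) and the example directories are illustrative:

```bash
# Trigger the manual example check for two illustrative example directories.
gh workflow run example_check_on_dispatch.yml \
    -f example_directory="language/gpt,images/vit"
```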
|
|
@ -20,7 +20,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.7.12'
|
||||
python-version: '3.8.14'
|
||||
- name: generate draft
|
||||
id: generate_draft
|
||||
run: |
|
||||
|
@ -42,4 +42,3 @@ jobs:
|
|||
body_path: ${{ steps.generate_draft.outputs.path }}
|
||||
draft: True
|
||||
prerelease: false
|
||||
|
|
@ -64,9 +64,21 @@ jobs:
|
|||
- name: Copy scripts and checkout
|
||||
run: |
|
||||
cp -r ./.github/workflows/scripts/* ./
|
||||
|
||||
# link the cache directories to the current path
|
||||
ln -s /github/home/conda_pkgs ./conda_pkgs
|
||||
ln -s /github/home/pip_wheels ./pip_wheels
|
||||
|
||||
# set the conda package path
|
||||
echo "pkgs_dirs:\n - $PWD/conda_pkgs" > ~/.condarc
|
||||
|
||||
# set safe directory
|
||||
git config --global --add safe.directory /__w/ColossalAI/ColossalAI
|
||||
|
||||
# check out
|
||||
git checkout $git_ref
|
||||
|
||||
# get cub package for cuda 10.2
|
||||
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
|
||||
unzip 1.8.0.zip
|
||||
env:
|
||||
|
|
|
@ -18,23 +18,17 @@ jobs:
|
|||
with:
|
||||
fetch-depth: 0
|
||||
- name: Build Docker
|
||||
id: build
|
||||
run: |
|
||||
version=$(cat version.txt)
|
||||
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t hpcaitech/colossalai:$version ./docker
|
||||
tag=hpcaitech/colossalai:$version
|
||||
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker
|
||||
echo "tag=${tag}" >> $GITHUB_OUTPUT
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
|
||||
with:
|
||||
images: hpcaitech/colossalai
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
- name: Push Docker image
|
||||
run: |
|
||||
docker push ${{ steps.build.outputs.tag }}
|
||||
|
|
|
@ -1,74 +1,29 @@
|
|||
name: Release bdist wheel for Nightly versions
|
||||
name: Publish Nightly Version to PyPI
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# run at 00:00 of every Sunday
|
||||
- cron: '0 0 * * 6'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
matrix_preparation:
|
||||
name: Prepare Container List
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- id: set-matrix
|
||||
run: |
|
||||
matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]"
|
||||
echo $matrix
|
||||
echo "::set-output name=matrix::{\"container\":$(echo $matrix)}"
|
||||
schedule:
|
||||
- cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time
|
||||
|
||||
build:
|
||||
name: Release bdist wheels
|
||||
needs: matrix_preparation
|
||||
if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
|
||||
runs-on: [self-hosted, gpu]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
|
||||
container:
|
||||
image: ${{ matrix.container }}
|
||||
options: --gpus all --rm
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
# cub is for cuda 10.2
|
||||
- name: Copy scripts and checkout
|
||||
run: |
|
||||
cp -r ./.github/workflows/scripts/* ./
|
||||
ln -s /github/home/pip_wheels ./pip_wheels
|
||||
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
|
||||
unzip 1.8.0.zip
|
||||
- name: Build bdist wheel
|
||||
run: |
|
||||
pip install beautifulsoup4 requests packaging
|
||||
python ./build_colossalai_wheel.py --nightly
|
||||
- name: 🚀 Deploy
|
||||
uses: garygrossgarten/github-action-scp@release
|
||||
with:
|
||||
local: all_dist
|
||||
remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
|
||||
host: ${{ secrets.PRIVATE_PYPI_HOST }}
|
||||
username: ${{ secrets.PRIVATE_PYPI_USER }}
|
||||
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
|
||||
remove_old_build:
|
||||
name: Remove old nightly build
|
||||
jobs:
|
||||
build-n-publish:
|
||||
if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI'
|
||||
name: Build and publish Python 🐍 distributions 📦 to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: executing remote ssh commands using password
|
||||
uses: appleboy/ssh-action@master
|
||||
env:
|
||||
BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
|
||||
with:
|
||||
host: ${{ secrets.PRIVATE_PYPI_HOST }}
|
||||
username: ${{ secrets.PRIVATE_PYPI_USER }}
|
||||
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
|
||||
envs: BUILD_DIR
|
||||
script: |
|
||||
cd $BUILD_DIR
|
||||
find . -type f -mtime +0 -exec rm -f {} +
|
||||
script_stop: true
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.8.14'
|
||||
|
||||
- run: NIGHTLY=1 python setup.py sdist build
|
||||
|
||||
# publish to PyPI if executed on the main branch
|
||||
- name: Publish package to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
verbose: true
|
||||
|
|
|
@ -1,21 +1,29 @@
|
|||
name: Publish to PyPI
|
||||
|
||||
on: workflow_dispatch
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'version.txt'
|
||||
types:
|
||||
- closed
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
if: github.ref_name == 'main' && github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
|
||||
if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI' && github.event.pull_request.merged == true && github.base_ref == 'main'
|
||||
name: Build and publish Python 🐍 distributions 📦 to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.7.12'
|
||||
python-version: '3.8.14'
|
||||
|
||||
- run: python setup.py sdist build
|
||||
|
||||
# publish to PyPI if executed on the main branch
|
||||
# publish to Test PyPI if executed on the develop branch
|
||||
- name: Publish package to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
|
@ -1,25 +0,0 @@
|
|||
name: Publish to Test PyPI
|
||||
|
||||
on: workflow_dispatch
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
|
||||
name: Build and publish Python 🐍 distributions 📦 to Test PyPI
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.7.12'
|
||||
- run: python setup.py sdist build
|
||||
# publish to PyPI if executed on the main branch
|
||||
# publish to Test PyPI if executed on the develop branch
|
||||
- name: Publish package to Test PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
|
||||
repository_url: https://test.pypi.org/legacy/
|
||||
verbose: true
|
|
@ -1,12 +1,13 @@
|
|||
from filecmp import cmp
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from packaging import version
|
||||
from filecmp import cmp
|
||||
from functools import cmp_to_key
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from packaging import version
|
||||
|
||||
WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels'
|
||||
RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels'
|
||||
CUDA_HOME = os.environ['CUDA_HOME']
|
||||
|
|
|
@ -18,7 +18,7 @@ if [ $1 == "pip" ]
|
|||
then
|
||||
wget -nc -q -O ./pip_wheels/$filename $url
|
||||
pip install ./pip_wheels/$filename
|
||||
|
||||
|
||||
elif [ $1 == 'conda' ]
|
||||
then
|
||||
conda install pytorch==$torch_version cudatoolkit=$cuda_version $flags
|
||||
|
@ -34,8 +34,9 @@ fi
|
|||
|
||||
python setup.py bdist_wheel
|
||||
mv ./dist/* ./all_dist
|
||||
# must remove build to enable compilation for
|
||||
# cuda extension in the next build
|
||||
rm -rf ./build
|
||||
python setup.py clean
|
||||
conda deactivate
|
||||
conda env remove -n $python_version
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def check_inputs(input_list):
|
||||
for path in input_list:
|
||||
real_path = os.path.join('examples', path)
|
||||
if not os.path.exists(real_path):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
|
||||
args = parser.parse_args()
|
||||
name_list = args.fileNameList.split(",")
|
||||
is_correct = check_inputs(name_list)
|
||||
|
||||
if is_correct:
|
||||
print('success')
|
||||
else:
|
||||
print('failure')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
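
A hedged usage sketch of this checker, run from the repository root (the directories are illustrative; in CI it is invoked by the manual example-check workflow above):

```bash
# Verify that each requested directory exists under examples/.
# Prints "success" if all paths exist, otherwise "failure".
python .github/workflows/scripts/example_checks/check_dispatch_inputs.py \
    --fileNameList language/gpt,images/vit
```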
|
|
@ -0,0 +1,37 @@
|
|||
import os
|
||||
|
||||
|
||||
def show_files(path, all_files):
|
||||
# Traverse all folders/files in the current directory
|
||||
file_list = os.listdir(path)
|
||||
# Determine whether each element is a folder or a file: append files to the list, recurse into folders.
|
||||
for file_name in file_list:
|
||||
# Get the absolute path using os.path.join() and store it in cur_path.
|
||||
cur_path = os.path.join(path, file_name)
|
||||
# Check whether it is a folder
|
||||
if os.path.isdir(cur_path):
|
||||
show_files(cur_path, all_files)
|
||||
else:
|
||||
all_files.append(cur_path)
|
||||
return all_files
|
||||
|
||||
|
||||
def join(input_list, sep=None):
|
||||
return (sep or ' ').join(input_list)
|
||||
|
||||
|
||||
def main():
|
||||
contents = show_files('examples/', [])
|
||||
all_loc = []
|
||||
for file_loc in contents:
|
||||
split_loc = file_loc.split('/')
|
||||
# A path must have two sub-folder levels under the examples folder: examples/images/vit is acceptable, while examples/images/README.md and examples/requirements.txt are not.
|
||||
if len(split_loc) >= 4:
|
||||
re_loc = '/'.join(split_loc[1:3])
|
||||
if re_loc not in all_loc:
|
||||
all_loc.append(re_loc)
|
||||
print(all_loc)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
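
A minimal invocation sketch, assuming it is run from the repository root as in the `matrix_preparation` job of the weekly workflow:

```bash
# List every examples/<area>/<application> directory that contains files.
# The output is a Python-style list, e.g. ['images/vit', 'language/gpt'].
python .github/workflows/scripts/example_checks/check_example_weekly.py
```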
|
|
@ -0,0 +1,24 @@
|
|||
import argparse
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
|
||||
args = parser.parse_args()
|
||||
name_list = args.fileNameList.split(":")
|
||||
folder_need_check = set()
|
||||
for loc in name_list:
|
||||
# Keep only the <area>/<application> sub-folders of the examples folder
|
||||
# the examples folder structure is like
|
||||
# - examples
|
||||
# - area
|
||||
# - application
|
||||
# - file
|
||||
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
|
||||
folder_need_check.add('/'.join(loc.split("/")[1:3]))
|
||||
# Print the result so that the calling shell can capture it.
|
||||
print(list(folder_need_check))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
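
A hedged usage sketch; the colon-separated list mirrors how the `setup matrix` step builds `changedFileName`, and the file paths here are purely illustrative:

```bash
# Only paths of the form examples/<area>/<application>/... survive the filter,
# so the README change is ignored. Prints e.g. ['language/gpt'].
python .github/workflows/scripts/example_checks/detect_changed_example.py \
    --fileNameList "examples/language/gpt/train.py:README.md:"
```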
|
|
@ -2,9 +2,10 @@
|
|||
# coding: utf-8
|
||||
|
||||
import argparse
|
||||
import requests
|
||||
import re
|
||||
import os
|
||||
import re
|
||||
|
||||
import requests
|
||||
|
||||
COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits'
|
||||
TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags'
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
name: Synchronize Submodule
|
||||
|
||||
on:
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
@ -27,11 +27,11 @@ jobs:
|
|||
|
||||
- name: Commit update
|
||||
run: |
|
||||
git config --global user.name 'github-actions'
|
||||
git config --global user.email 'github-actions@github.com'
|
||||
git config --global user.name 'github-actions'
|
||||
git config --global user.email 'github-actions@github.com'
|
||||
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
|
||||
git commit -am "Automated submodule synchronization"
|
||||
|
||||
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v3
|
||||
with:
|
||||
|
@ -43,4 +43,3 @@ jobs:
|
|||
assignees: ${{ github.actor }}
|
||||
delete-branch: true
|
||||
branch: create-pull-request/patch-sync-submodule
|
||||
|
|
@ -134,10 +134,23 @@ dmypy.json
|
|||
.vscode/
|
||||
|
||||
# macos
|
||||
.DS_Store
|
||||
*.DS_Store
|
||||
#data/
|
||||
|
||||
docs/.build
|
||||
|
||||
# pytorch checkpoint
|
||||
*.pt
|
||||
*.pt
|
||||
|
||||
# ignore version.py generated by setup.py
|
||||
colossalai/version.py
|
||||
|
||||
# ignore any kernel build files
|
||||
.o
|
||||
.so
|
||||
|
||||
# ignore python interface definition files
|
||||
.pyi
|
||||
|
||||
# ignore coverage test file
|
||||
coverage.lcov
|
||||
|
|
|
@ -27,4 +27,4 @@ sphinx:
|
|||
python:
|
||||
install:
|
||||
- requirements: requirements/requirements.txt
|
||||
- requirements: docs/requirements.txt
|
||||
- requirements: docs/requirements.txt
|
||||
|
|
LICENSE
|
@ -1,4 +1,4 @@
|
|||
Copyright 2021- The Colossal-ai Authors. All rights reserved.
|
||||
Copyright 2021- HPC-AI Technology Inc. All rights reserved.
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
@ -187,7 +187,7 @@ Copyright 2021- The Colossal-ai Authors. All rights reserved.
|
|||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright 2021- HPC-AI Technology Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
include *.txt README.md
|
||||
recursive-include requirements *.txt
|
||||
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
|
||||
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
|
||||
recursive-include op_builder *.py
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
# Colossal-AI
|
||||
<div id="top" align="center">
|
||||
|
||||
[](https://www.colossalai.org/)
|
||||
[](https://www.colossalai.org/)
|
||||
|
||||
Colossal-AI: 一个面向大模型时代的通用深度学习系统
|
||||
|
||||
<h3> <a href="https://arxiv.org/abs/2110.14883"> 论文 </a> |
|
||||
<a href="https://www.colossalai.org/"> 文档 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> 例程 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
|
||||
<h3> <a href="https://arxiv.org/abs/2110.14883"> 论文 </a> |
|
||||
<a href="https://www.colossalai.org/"> 文档 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> 例程 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
|
||||
<a href="https://medium.com/@hpcaitech"> 博客 </a></h3>
|
||||
|
||||
[](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
|
||||
|
@ -22,41 +22,50 @@
|
|||
|
||||
</div>
|
||||
|
||||
## 新闻
|
||||
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
|
||||
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
|
||||
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
|
||||
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
|
||||
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
|
||||
|
||||
|
||||
## 目录
|
||||
<ul>
|
||||
<li><a href="#为何选择-Colossal-AI">为何选择 Colossal-AI</a> </li>
|
||||
<li><a href="#特点">特点</a> </li>
|
||||
<li>
|
||||
<a href="#并行训练样例展示">并行训练样例展示</a>
|
||||
<a href="#并行训练样例展示">并行训练样例展示</a>
|
||||
<ul>
|
||||
<li><a href="#ViT">ViT</a></li>
|
||||
<li><a href="#GPT-3">GPT-3</a></li>
|
||||
<li><a href="#GPT-2">GPT-2</a></li>
|
||||
<li><a href="#BERT">BERT</a></li>
|
||||
<li><a href="#PaLM">PaLM</a></li>
|
||||
<li><a href="#OPT">OPT</a></li>
|
||||
<li><a href="#ViT">ViT</a></li>
|
||||
<li><a href="#推荐系统模型">推荐系统模型</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#单GPU训练样例展示">单GPU训练样例展示</a>
|
||||
<a href="#单GPU训练样例展示">单GPU训练样例展示</a>
|
||||
<ul>
|
||||
<li><a href="#GPT-2-Single">GPT-2</a></li>
|
||||
<li><a href="#PaLM-Single">PaLM</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#推理-Energon-AI-样例展示">推理 (Energon-AI) 样例展示</a>
|
||||
<a href="#推理-Energon-AI-样例展示">推理 (Energon-AI) 样例展示</a>
|
||||
<ul>
|
||||
<li><a href="#GPT-3-Inference">GPT-3</a></li>
|
||||
<li><a href="#OPT-Serving">1750亿参数OPT在线推理服务</a></li>
|
||||
<li><a href="#BLOOM-Inference">1750亿参数 BLOOM</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI 成功案例</a>
|
||||
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI 成功案例</a>
|
||||
<ul>
|
||||
<li><a href="#xTrimoMultimer">xTrimoMultimer: 蛋白质单体与复合物结构预测</a></li>
|
||||
<li><a href="#AIGC">AIGC: 加速 Stable Diffusion</a></li>
|
||||
<li><a href="#生物医药">生物医药: 加速AlphaFold蛋白质结构预测</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
|
@ -69,11 +78,6 @@
|
|||
<li><a href="#使用-Docker">使用 Docker</a></li>
|
||||
<li><a href="#社区">社区</a></li>
|
||||
<li><a href="#做出贡献">做出贡献</a></li>
|
||||
<li><a href="#快速预览">快速预览</a></li>
|
||||
<ul>
|
||||
<li><a href="#几行代码开启分布式训练">几行代码开启分布式训练</a></li>
|
||||
<li><a href="#构建一个简单的2维并行模型">构建一个简单的2维并行模型</a></li>
|
||||
</ul>
|
||||
<li><a href="#引用我们">引用我们</a></li>
|
||||
</ul>
|
||||
|
||||
|
@ -98,6 +102,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
- 1维, [2维](https://arxiv.org/abs/2104.05343), [2.5维](https://arxiv.org/abs/2105.14500), [3维](https://arxiv.org/abs/2105.14450) 张量并行
|
||||
- [序列并行](https://arxiv.org/abs/2105.13120)
|
||||
- [零冗余优化器 (ZeRO)](https://arxiv.org/abs/1910.02054)
|
||||
- [自动并行](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
|
||||
- 异构内存管理
|
||||
- [PatrickStar](https://arxiv.org/abs/2108.05818)
|
||||
- 使用友好
|
||||
|
@ -105,16 +110,11 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
- 推理
|
||||
- [Energon-AI](https://github.com/hpcaitech/EnergonAI)
|
||||
- Colossal-AI 成功案例
|
||||
- [xTrimoMultimer: 蛋白质单体与复合物结构预测](https://github.com/biomap-research/xTrimoMultimer)
|
||||
- 生物医药: [FastFold](https://github.com/hpcaitech/FastFold) 加速蛋白质结构预测 AlphaFold 训练与推理
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
## 并行训练样例展示
|
||||
### ViT
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
|
||||
</p>
|
||||
|
||||
- 14倍批大小和5倍训练速度(张量并行=64)
|
||||
|
||||
### GPT-3
|
||||
<p align="center">
|
||||
|
@ -131,7 +131,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/(updated)GPT-2.png" width=800>
|
||||
|
||||
- 用相同的硬件训练24倍大的模型
|
||||
- 超3倍的吞吐量
|
||||
- 超3倍的吞吐量
|
||||
|
||||
### BERT
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BERT.png" width=800/>
|
||||
|
@ -145,10 +145,16 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
|
||||
|
||||
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), 由Meta发布的1750亿语言模型,由于完全公开了预训练参数权重,因此促进了下游任务和应用部署的发展。
|
||||
- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://service.colossalai.org/opt)
|
||||
- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://service.colossalai.org/opt)
|
||||
|
||||
请访问我们的 [文档](https://www.colossalai.org/) 和 [例程](https://github.com/hpcaitech/ColossalAI-Examples) 以了解详情。
|
||||
|
||||
### ViT
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
|
||||
</p>
|
||||
|
||||
- 14倍批大小和5倍训练速度(张量并行=64)
|
||||
|
||||
### 推荐系统模型
|
||||
- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), 使用软件Cache实现Embeddings,用更少GPU显存训练更大的模型。
|
||||
|
@ -178,7 +184,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
|
||||
- 用相同的硬件训练34倍大的模型
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
|
||||
## 推理 (Energon-AI) 样例展示
|
||||
|
@ -195,23 +201,82 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
|
|||
|
||||
- [OPT推理服务](https://service.colossalai.org/opt): 无需注册,免费体验1750亿参数OPT在线推理服务
|
||||
|
||||
<p id="BLOOM-Inference" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
|
||||
</p>
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): 降低1750亿参数BLOOM模型部署推理成本超10倍
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
## Colossal-AI 成功案例
|
||||
|
||||
### xTrimoMultimer: 蛋白质单体与复合物结构预测
|
||||
### AIGC
|
||||
加速AIGC(AI内容生成)模型,如[Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) 和 [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion)
|
||||
|
||||
<p id="diffusion_train" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20v2.png" width=800/>
|
||||
</p>
|
||||
|
||||
- [训练](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): 减少5.6倍显存消耗,硬件成本最高降低46倍(从A100到RTX3060)
|
||||
|
||||
<p id="diffusion_demo" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/DreamBooth.png" width=800/>
|
||||
</p>
|
||||
|
||||
- [DreamBooth微调](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): 仅需3-5张目标主题图像个性化微调
|
||||
|
||||
<p id="inference" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20Inference.jpg" width=800/>
|
||||
</p>
|
||||
|
||||
- [推理](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): GPU推理显存消耗降低2.5倍
|
||||
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
### 生物医药
|
||||
|
||||
加速 [AlphaFold](https://alphafold.ebi.ac.uk/) 蛋白质结构预测
|
||||
|
||||
<p id="FastFold" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/FastFold.jpg" width=800/>
|
||||
</p>
|
||||
|
||||
- [FastFold](https://github.com/hpcaitech/FastFold): 加速AlphaFold训练与推理、数据前处理、推理序列长度超过10000残基
|
||||
|
||||
<p id="xTrimoMultimer" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTM_Prediction.jpg" width=380/>
|
||||
<p></p>
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTrimoMultimer_Table.jpg" width=800/>
|
||||
</p>
|
||||
|
||||
- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): 11倍加速蛋白质单体与复合物结构预测
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
## 安装
|
||||
|
||||
### 从PyPI安装
|
||||
|
||||
您可以用下面的命令直接从PyPI上下载并安装Colossal-AI。我们默认不会安装PyTorch扩展包
|
||||
|
||||
```bash
|
||||
pip install colossalai
|
||||
```
|
||||
|
||||
但是,如果你想在安装时就直接构建PyTorch扩展,您可以设置环境变量`CUDA_EXT=1`.
|
||||
|
||||
```bash
|
||||
CUDA_EXT=1 pip install colossalai
|
||||
```
|
||||
|
||||
**否则，PyTorch扩展只会在你实际需要使用它们时，在运行时构建。**
|
||||
|
||||
与此同时,我们也每周定时发布Nightly版本,这能让你提前体验到新的feature和bug fix。你可以通过以下命令安装Nightly版本。
|
||||
|
||||
```bash
|
||||
pip install colossalai-nightly
|
||||
```
|
||||
|
||||
### 从官方安装
|
||||
|
||||
您可以访问我们[下载](https://www.colossalai.org/download)页面来安装Colossal-AI,在这个页面上发布的版本都预编译了CUDA扩展。
|
||||
|
@ -231,10 +296,10 @@ pip install -r requirements/requirements.txt
|
|||
pip install .
|
||||
```
|
||||
|
||||
如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装):
|
||||
我们默认在`pip install`时不安装PyTorch扩展，而是在运行时临时编译。如果你想要提前安装这些扩展（在使用融合优化器时会用到），可以使用以下命令。
|
||||
|
||||
```shell
|
||||
NO_CUDA_EXT=1 pip install .
|
||||
CUDA_EXT=1 pip install .
|
||||
```
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
@ -283,31 +348,6 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash
|
|||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
## 快速预览
|
||||
|
||||
### 几行代码开启分布式训练
|
||||
|
||||
```python
|
||||
parallel = dict(
|
||||
pipeline=2,
|
||||
tensor=dict(mode='2.5d', depth = 1, size=4)
|
||||
)
|
||||
```
|
||||
|
||||
### 几行代码开启异构训练
|
||||
|
||||
```python
|
||||
zero = dict(
|
||||
model_config=dict(
|
||||
tensor_placement_policy='auto',
|
||||
shard_strategy=TensorShardStrategy(),
|
||||
reuse_fp16_shard=True
|
||||
),
|
||||
optimizer_config=dict(initial_scale=2**5, gpu_margin_mem_ratio=0.2)
|
||||
)
|
||||
```
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
||||
## 引用我们
|
||||
|
||||
|
@ -320,4 +360,4 @@ zero = dict(
|
|||
}
|
||||
```
|
||||
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
<p align="right">(<a href="#top">返回顶端</a>)</p>
|
||||
|
|
README.md
|
@ -1,14 +1,14 @@
|
|||
# Colossal-AI
|
||||
<div id="top" align="center">
|
||||
|
||||
[](https://www.colossalai.org/)
|
||||
[](https://www.colossalai.org/)
|
||||
|
||||
Colossal-AI: A Unified Deep Learning System for Big Model Era
|
||||
|
||||
<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
|
||||
<a href="https://www.colossalai.org/"> Documentation </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> Examples </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
|
||||
<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
|
||||
<a href="https://www.colossalai.org/"> Documentation </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> Examples </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
|
||||
<a href="https://medium.com/@hpcaitech"> Blog </a></h3>
|
||||
|
||||
[](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
|
||||
|
@ -17,46 +17,55 @@
|
|||
[](https://huggingface.co/hpcai-tech)
|
||||
[](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
|
||||
[](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)
|
||||
|
||||
|
||||
|
||||
| [English](README.md) | [中文](README-zh-Hans.md) |
|
||||
|
||||
</div>
|
||||
|
||||
## Latest News
|
||||
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
|
||||
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
|
||||
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
|
||||
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
|
||||
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
|
||||
|
||||
## Table of Contents
|
||||
<ul>
|
||||
<li><a href="#Why-Colossal-AI">Why Colossal-AI</a> </li>
|
||||
<li><a href="#Features">Features</a> </li>
|
||||
<li>
|
||||
<a href="#Parallel-Training-Demo">Parallel Training Demo</a>
|
||||
<a href="#Parallel-Training-Demo">Parallel Training Demo</a>
|
||||
<ul>
|
||||
<li><a href="#ViT">ViT</a></li>
|
||||
<li><a href="#GPT-3">GPT-3</a></li>
|
||||
<li><a href="#GPT-2">GPT-2</a></li>
|
||||
<li><a href="#BERT">BERT</a></li>
|
||||
<li><a href="#PaLM">PaLM</a></li>
|
||||
<li><a href="#OPT">OPT</a></li>
|
||||
<li><a href="#ViT">ViT</a></li>
|
||||
<li><a href="#Recommendation-System-Models">Recommendation System Models</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#Single-GPU-Training-Demo">Single GPU Training Demo</a>
|
||||
<a href="#Single-GPU-Training-Demo">Single GPU Training Demo</a>
|
||||
<ul>
|
||||
<li><a href="#GPT-2-Single">GPT-2</a></li>
|
||||
<li><a href="#PaLM-Single">PaLM</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#Inference-Energon-AI-Demo">Inference (Energon-AI) Demo</a>
|
||||
<a href="#Inference-Energon-AI-Demo">Inference (Energon-AI) Demo</a>
|
||||
<ul>
|
||||
<li><a href="#GPT-3-Inference">GPT-3</a></li>
|
||||
<li><a href="#OPT-Serving">OPT-175B Online Serving for Text Generation</a></li>
|
||||
<li><a href="#BLOOM-Inference">175B BLOOM</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI for Real World Applications</a>
|
||||
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI for Real World Applications</a>
|
||||
<ul>
|
||||
<li><a href="#xTrimoMultimer">xTrimoMultimer: Accelerating Protein Monomer and Multimer Structure Prediction</a></li>
|
||||
<li><a href="#AIGC">AIGC: Acceleration of Stable Diffusion</a></li>
|
||||
<li><a href="#Biomedicine">Biomedicine: Acceleration of AlphaFold Protein Structure</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>
|
||||
|
@ -69,11 +78,6 @@
|
|||
<li><a href="#Use-Docker">Use Docker</a></li>
|
||||
<li><a href="#Community">Community</a></li>
|
||||
<li><a href="#contributing">Contributing</a></li>
|
||||
<li><a href="#Quick-View">Quick View</a></li>
|
||||
<ul>
|
||||
<li><a href="#Start-Distributed-Training-in-Lines">Start Distributed Training in Lines</a></li>
|
||||
<li><a href="#Write-a-Simple-2D-Parallel-Model">Write a Simple 2D Parallel Model</a></li>
|
||||
</ul>
|
||||
<li><a href="#Cite-Us">Cite Us</a></li>
|
||||
</ul>
|
||||
|
||||
|
@ -100,8 +104,9 @@ distributed training and inference in a few lines.
|
|||
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism
|
||||
- [Sequence Parallelism](https://arxiv.org/abs/2105.13120)
|
||||
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
|
||||
- [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
|
||||
|
||||
- Heterogeneous Memory Management
|
||||
- Heterogeneous Memory Management
|
||||
- [PatrickStar](https://arxiv.org/abs/2108.05818)
|
||||
|
||||
- Friendly Usage
|
||||
|
@ -110,17 +115,11 @@ distributed training and inference in a few lines.
|
|||
- Inference
|
||||
- [Energon-AI](https://github.com/hpcaitech/EnergonAI)
|
||||
|
||||
- Colossal-AI in the Real World
|
||||
- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): Accelerating Protein Monomer and Multimer Structure Prediction
|
||||
- Colossal-AI in the Real World
|
||||
- Biomedicine: [FastFold](https://github.com/hpcaitech/FastFold) accelerates training and inference of AlphaFold protein structure
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
||||
## Parallel Training Demo
|
||||
### ViT
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
|
||||
</p>
|
||||
|
||||
- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64
|
||||
|
||||
### GPT-3
|
||||
<p align="center">
|
||||
|
@ -150,10 +149,17 @@ distributed training and inference in a few lines.
|
|||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
|
||||
|
||||
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-billion-parameter AI language model released by Meta, which stimulates AI programmers to perform various downstream tasks and application deployments because its pretrained model weights are publicly available.
|
||||
- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt)
|
||||
- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt)
|
||||
|
||||
Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details.
|
||||
|
||||
### ViT
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
|
||||
</p>
|
||||
|
||||
- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64
|
||||
|
||||
### Recommendation System Models
|
||||
- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding): utilizes a software cache to train larger embedding tables with a smaller GPU memory budget.
|
||||
|
||||
|
@ -198,26 +204,85 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
|
|||
|
||||
- [OPT Serving](https://service.colossalai.org/opt): Try 175-billion-parameter OPT online services for free, without any registration whatsoever.
|
||||
|
||||
<p id="BLOOM-Inference" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
|
||||
</p>
|
||||
|
||||
- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): Reduce hardware deployment costs of 175-billion-parameter BLOOM by more than 10 times.
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
||||
## Colossal-AI in the Real World
|
||||
|
||||
### xTrimoMultimer: Accelerating Protein Monomer and Multimer Structure Prediction
|
||||
### AIGC
|
||||
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion).
|
||||
<p id="diffusion_train" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20v2.png" width=800/>
|
||||
</p>
|
||||
|
||||
- [Training](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce Stable Diffusion memory consumption by up to 5.6x and hardware cost by up to 46x (from A100 to RTX3060).
|
||||
|
||||
<p id="diffusion_demo" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/DreamBooth.png" width=800/>
|
||||
</p>
|
||||
|
||||
- [DreamBooth Fine-tuning](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): Personalize your model using just 3-5 images of the desired subject.
|
||||
|
||||
<p id="inference" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20Inference.jpg" width=800/>
|
||||
</p>
|
||||
|
||||
- [Inference](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce inference GPU memory consumption by 2.5x.
|
||||
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
||||
### Biomedicine
|
||||
Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
|
||||
|
||||
<p id="FastFold" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/FastFold.jpg" width=800/>
|
||||
</p>
|
||||
|
||||
- [FastFold](https://github.com/hpcaitech/FastFold): accelerating training and inference on GPU clusters, faster data processing, and inference on sequences containing more than 10,000 residues.
|
||||
|
||||
<p id="xTrimoMultimer" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTM_Prediction.jpg" width=380/>
|
||||
<p></p>
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTrimoMultimer_Table.jpg" width=800/>
|
||||
</p>
|
||||
|
||||
- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerating structure prediction of protein monomers and multimer by 11x
|
||||
- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerating structure prediction of protein monomers and multimer by 11x.
|
||||
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
||||
## Installation
|
||||
|
||||
### Install from PyPI
|
||||
|
||||
You can easily install Colossal-AI with the following command. **By default, we do not build PyTorch extensions during installation.**
|
||||
|
||||
```bash
|
||||
pip install colossalai
|
||||
```
|
||||
|
||||
However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
|
||||
|
||||
```bash
|
||||
CUDA_EXT=1 pip install colossalai
|
||||
```
|
||||
|
||||
**Otherwise, CUDA kernels will be built at runtime when you actually need them.**
|
||||
|
||||
We also release a nightly version to PyPI on a weekly basis. This allows you to access unreleased features and bug fixes from the main branch.
Installation can be made via
|
||||
|
||||
```bash
|
||||
pip install colossalai-nightly
|
||||
```
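
To confirm which build you ended up with, a quick sanity check (a hedged sketch; it relies on the `__version__` attribute that `colossalai/__init__.py` exposes later in this diff):

```bash
# An installed wheel reports its release version; a source tree used via
# PYTHONPATH without `pip install` falls back to 0.0.0.
python -c "import colossalai; print(colossalai.__version__)"
```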
|
||||
|
||||
### Download From Official Releases
|
||||
|
||||
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built CUDA extensions.
|
||||
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built PyTorch extensions.
|
||||
|
||||
|
||||
### Download From Source
|
||||
|
@ -228,17 +293,15 @@ You can visit the [Download](https://www.colossalai.org/download) page to downlo
|
|||
git clone https://github.com/hpcaitech/ColossalAI.git
|
||||
cd ColossalAI
|
||||
|
||||
# install dependency
|
||||
pip install -r requirements/requirements.txt
|
||||
|
||||
# install colossalai
|
||||
pip install .
|
||||
```
|
||||
|
||||
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
|
||||
By default, we do not compile CUDA/C++ kernels. ColossalAI will build them during runtime.
|
||||
If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
|
||||
|
||||
```shell
|
||||
NO_CUDA_EXT=1 pip install .
|
||||
CUDA_EXT=1 pip install .
|
||||
```
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
@ -289,32 +352,6 @@ Thanks so much to all of our amazing contributors!
|
|||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
||||
## Quick View
|
||||
|
||||
### Start Distributed Training in Lines
|
||||
|
||||
```python
|
||||
parallel = dict(
|
||||
pipeline=2,
|
||||
tensor=dict(mode='2.5d', depth = 1, size=4)
|
||||
)
|
||||
```
|
||||
|
||||
### Start Heterogeneous Training in Lines
|
||||
|
||||
```python
|
||||
zero = dict(
|
||||
model_config=dict(
|
||||
tensor_placement_policy='auto',
|
||||
shard_strategy=TensorShardStrategy(),
|
||||
reuse_fp16_shard=True
|
||||
),
|
||||
optimizer_config=dict(initial_scale=2**5, gpu_margin_mem_ratio=0.2)
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
||||
## Cite Us
|
||||
|
||||
|
|
|
@ -7,4 +7,11 @@ from .initialize import (
|
|||
launch_from_torch,
|
||||
)
|
||||
|
||||
__version__ = '0.1.11rc1'
|
||||
try:
|
||||
# .version will be created by setup.py
|
||||
from .version import __version__
|
||||
except ModuleNotFoundError:
|
||||
# this will only happen if the user did not run `pip install`
|
||||
# and directly set PYTHONPATH to use Colossal-AI which is a bad practice
|
||||
__version__ = '0.0.0'
|
||||
print('please install Colossal-AI from https://www.colossalai.org/download or from source')
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from .apex_amp import ApexAMPOptimizer
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .apex_amp import ApexAMPOptimizer
|
||||
|
||||
|
||||
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
|
||||
r"""A helper function to wrap training components with Apex AMP modules
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import inspect
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.utils import is_no_pp_or_last_stage
|
||||
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
|
||||
from .grad_scaler import DynamicGradScaler, ConstantGradScaler
|
||||
|
||||
from ._fp16_optimizer import FP16Optimizer
|
||||
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
|
||||
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
|
||||
|
||||
|
||||
def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
|
||||
|
|
|
@ -3,24 +3,33 @@
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.kernel.op_builder import FusedOptimBuilder
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier
|
||||
|
||||
from ._utils import has_inf_or_nan, zero_gard_by_list
|
||||
from .grad_scaler import BaseGradScaler
|
||||
|
||||
try:
|
||||
import colossal_C
|
||||
from colossalai._C import fused_optim
|
||||
except:
|
||||
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier)
|
||||
from torch.distributed import ProcessGroup
|
||||
from .grad_scaler import BaseGradScaler
|
||||
from ._utils import has_inf_or_nan, zero_gard_by_list
|
||||
fused_optim = None
|
||||
|
||||
__all__ = ['FP16Optimizer']
|
||||
|
||||
|
||||
def load_fused_optim():
|
||||
global fused_optim
|
||||
|
||||
if fused_optim is None:
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
|
||||
|
||||
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
||||
"""
|
||||
adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
|
||||
|
@ -33,7 +42,9 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
|||
if overflow_buf:
|
||||
overflow_buf.fill_(0)
|
||||
# Scaling with factor `1.0` is equivalent to copy.
|
||||
multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
||||
global fused_optim
|
||||
load_fused_optim()
|
||||
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
||||
else:
|
||||
for this_, that_ in zip(this, that):
|
||||
that_.copy_(this_)
|
||||
|
@ -41,7 +52,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
|||
|
||||
class FP16Optimizer(Optimizer):
|
||||
"""Float16 optimizer for fp16 and bf16 data types.
|
||||
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD
|
||||
grad_scaler (BaseGradScaler): grad scaler for gradient chose in
|
||||
|
@ -73,8 +84,8 @@ class FP16Optimizer(Optimizer):
|
|||
|
||||
# get process group
|
||||
def _get_process_group(parallel_mode):
|
||||
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA):
|
||||
return gpc.get_group(ParallelMode.DATA)
|
||||
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
|
||||
return gpc.get_group(parallel_mode)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -150,6 +161,12 @@ class FP16Optimizer(Optimizer):
|
|||
f"==========================================",
|
||||
ranks=[0])
|
||||
|
||||
@property
|
||||
def max_norm(self):
|
||||
"""Returns the maximum norm of gradient clipping.
|
||||
"""
|
||||
return self._clip_grad_max_norm
|
||||
|
||||
@property
|
||||
def grad_scaler(self):
|
||||
"""Returns the gradient scaler.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from colossalai.logging import get_dist_logger
|
||||
from torch import Tensor
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
__all__ = ['BaseGradScaler']
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from .base_grad_scaler import BaseGradScaler
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .base_grad_scaler import BaseGradScaler
|
||||
|
||||
__all__ = ['DynamicGradScaler']
|
||||
|
||||
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from typing import Any
|
||||
from torch.optim import Optimizer
|
||||
from torch.distributed import ReduceOp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ReduceOp
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
|
||||
from ._fp16_optimizer import FP16Optimizer
|
||||
|
||||
|
||||
|
@ -40,7 +43,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
|
|||
return self.optim.step()
|
||||
|
||||
def clip_grad_norm(self, model: nn.Module, max_norm: float):
|
||||
pass
|
||||
if self.optim.max_norm == max_norm:
|
||||
return
|
||||
raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). "
|
||||
"If you have supplied clip_grad_norm in the amp_config, "
|
||||
"executing the method clip_grad_norm is not allowed.")
|
||||
|
||||
|
||||
class NaiveAMPModel(nn.Module):
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from colossalai.context import Config
|
||||
from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.context import Config
|
||||
|
||||
from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer
|
||||
|
||||
|
||||
def convert_to_torch_amp(model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
|
|
|
@ -3,16 +3,18 @@
|
|||
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
|
||||
# to support tensor parallel
|
||||
|
||||
import torch
|
||||
from collections import defaultdict, abc
|
||||
import warnings
|
||||
from collections import abc, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from colossalai.context import ParallelMode
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.core import global_context as gpc
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from packaging import version
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
class _MultiDeviceReplicator(object):
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.cuda.amp as torch_amp
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.optim import Optimizer
|
||||
from ._grad_scaler import GradScaler
|
||||
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import clip_grad_norm_fp32
|
||||
|
||||
from ._grad_scaler import GradScaler
|
||||
|
||||
|
||||
class TorchAMPOptimizer(ColossalaiOptimizer):
|
||||
"""A wrapper class which integrate Pytorch AMP with an optimizer
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from .ckpt_solver_base import CheckpointSolverBase
|
||||
from .ckpt_solver_chen import CheckpointSolverChen
|
||||
from .ckpt_solver_rotor import CheckpointSolverRotor
|
|
@ -0,0 +1,16 @@
|
|||
import os
|
||||
|
||||
from setuptools import Extension, setup
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ext_modules = [Extension(
|
||||
'rotorc',
|
||||
sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
|
||||
)]
|
||||
|
||||
setup(
|
||||
name='rotor c extension',
|
||||
version='0.1',
|
||||
description='rotor c extension for faster dp computing',
|
||||
ext_modules=ext_modules,
|
||||
)
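For reference, the extension can also be built by hand; this is roughly the command the rotor solver issues automatically when the compiled module is missing (run from this directory, and the build-lib path is an assumption):

```bash
# manual build of the rotorc extension; normally CheckpointSolverRotor builds it on demand
python build_c_ext.py build_ext --build-lib=.
```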
|
|
@ -0,0 +1,195 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import (
|
||||
runtime_apply,
|
||||
runtime_apply_for_iterable_object,
|
||||
runtime_comm_spec_apply,
|
||||
)
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
|
||||
__all__ = ['CheckpointSolverBase']
|
||||
|
||||
|
||||
def _copy_output(src: Graph, dst: Graph):
|
||||
"""Copy the output node from src to dst"""
|
||||
for n_src, n_dst in zip(src.nodes, dst.nodes):
|
||||
if n_src.op == 'output':
|
||||
n_dst.meta = n_src.meta
|
||||
|
||||
|
||||
def _get_param_size(module: torch.nn.Module):
|
||||
"""Get the size of the parameters in the module"""
|
||||
return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()])
|
||||
|
||||
|
||||
class CheckpointSolverBase(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
free_memory: float = -1.0,
|
||||
requires_linearize: bool = False,
|
||||
cnode: List[str] = None,
|
||||
optim_multiplier: float = 1.0,
|
||||
):
|
||||
"""``CheckpointSolverBase`` class will integrate information provided by the components
|
||||
and use an existing solver to find a possible optimal strategies combination for target
|
||||
computing graph.
|
||||
|
||||
Existing Solvers:
|
||||
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
|
||||
Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor)
|
||||
|
||||
Args:
|
||||
graph (Graph): The computing graph to be optimized.
|
||||
free_memory (float): Memory constraint for the solution.
|
||||
requires_linearize (bool): Whether the graph needs to be linearized.
|
||||
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
|
||||
optim_multiplier (float, optional): The multiplier of extra weight storage for the
|
||||
``torch.optim.Optimizer``. Default to 1.0.
|
||||
|
||||
Warnings:
|
||||
Meta information of the graph is required for any ``CheckpointSolver``.
|
||||
"""
|
||||
# super-dainiu: this graph is a temporary graph which can refer to
|
||||
# the owning module, but we will return another deepcopy of it after
|
||||
# the solver is executed.
|
||||
self.graph = deepcopy(graph)
|
||||
self.graph.owning_module = graph.owning_module
|
||||
_copy_output(graph, self.graph)
|
||||
self.graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
|
||||
# check if the graph has meta information
|
||||
if any(len(node.meta) == 0 for node in self.graph.nodes):
|
||||
raise RuntimeError(
|
||||
"Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
|
||||
)
|
||||
|
||||
# parameter memory = parameter size + optimizer extra weight storage
|
||||
self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
|
||||
self.cnode = cnode
|
||||
self.requires_linearize = requires_linearize
|
||||
if self.requires_linearize:
|
||||
self.node_list = self._linearize_graph()
|
||||
else:
|
||||
self.node_list = self.get_node_list()
|
||||
|
||||
@abstractmethod
|
||||
def solve(self):
|
||||
"""Solve the checkpointing problem and return the solution.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_node_list(self):
|
||||
"""Get the node list.
|
||||
"""
|
||||
return [[node] for node in self.graph.nodes]
|
||||
|
||||
def _linearize_graph(self) -> List[List[Node]]:
|
||||
"""Linearizing the graph
|
||||
|
||||
Args:
|
||||
graph (Graph): The computing graph to be optimized.
|
||||
|
||||
Returns:
|
||||
List[List[Node]]: List of lists; each inner list of Nodes represents
|
||||
an actual 'node' in the linearized manner.
|
||||
|
||||
Remarks:
|
||||
The inplace ops and shape-consistency ops are merged into the previous node.
|
||||
"""
|
||||
|
||||
# Common nodes are nodes that could be seen as attributes and remain
|
||||
# unchanged throughout the whole model; they will be used several times by
|
||||
# different blocks of the model, which makes it hard for us to linearize the graph
|
||||
# when we encounter those kinds of nodes. We let users annotate some of the
|
||||
# inputs as common nodes, such as the attention mask, and the following are some of
|
||||
# the ops that could actually be seen as common nodes. With our common node prop,
|
||||
# we could find some of the "real" common nodes (e.g. the real attention mask
|
||||
# used in BERT and GPT). The rule is simple: for a node whose parents are all common
|
||||
# nodes or whose op belongs to the following operations, we view this node as a
|
||||
# newly born common node.
|
||||
# List of target name that could be seen as common node
|
||||
common_ops = ["getattr", "getitem", "size"]
|
||||
|
||||
def _is_cop(target: Any) -> bool:
|
||||
"""Check if an op could be seen as common node
|
||||
|
||||
Args:
|
||||
target (Any): node target
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
if isinstance(target, str):
|
||||
return target in common_ops
|
||||
else:
|
||||
return target.__name__ in common_ops
|
||||
|
||||
def _is_sink() -> bool:
|
||||
"""Check if we can free all dependencies
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
def _is_inplace(n: Node):
|
||||
"""Get the inplace argument from ``torch.fx.Node``
|
||||
"""
|
||||
inplace = False
|
||||
if n.op == "call_function":
|
||||
inplace = n.kwargs.get("inplace", False)
|
||||
elif n.op == "call_module":
|
||||
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
|
||||
return inplace
|
||||
|
||||
def _is_shape_consistency(n: Node):
|
||||
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
|
||||
"""
|
||||
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
|
||||
|
||||
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
|
||||
map(_is_shape_consistency, n.users))
|
||||
|
||||
# make sure that item in cnode is valid
|
||||
if self.cnode:
|
||||
for name in self.cnode:
|
||||
try:
|
||||
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
|
||||
f"Common node {name} is not an input of the model."
|
||||
except StopIteration:
|
||||
raise ValueError(f"Common node name {name} not in graph.")
|
||||
|
||||
else:
|
||||
self.cnode = []
|
||||
|
||||
deps = {}
|
||||
node_list = []
|
||||
region = []
|
||||
|
||||
for n in self.graph.nodes:
|
||||
if n.op != "placeholder" and n.op != "output":
|
||||
for n_par in n.all_input_nodes:
|
||||
if n_par.op != "placeholder" and n_par.name not in self.cnode:
|
||||
deps[n_par] -= 1
|
||||
region.append(n)
|
||||
|
||||
# if the node can free all dependencies in the graph,
|
||||
# we can close the current region and start a new one
|
||||
if _is_sink():
|
||||
node_list.append(region)
|
||||
region = []
|
||||
|
||||
# propagate common node attr if possible
|
||||
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
|
||||
]) or _is_cop(n.target):
|
||||
self.cnode.append(n.name)
|
||||
else:
|
||||
deps[n] = len([user for user in n.users if user.op != "output"])
|
||||
return node_list
|
|
@ -0,0 +1,87 @@
|
|||
import math
|
||||
from copy import deepcopy
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
|
||||
|
||||
from .ckpt_solver_base import CheckpointSolverBase
|
||||
|
||||
__all__ = ['CheckpointSolverChen']
|
||||
|
||||
|
||||
class CheckpointSolverChen(CheckpointSolverBase):
|
||||
|
||||
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
Note that this algorithm targets memory optimization only, using the techniques in appendix A.
|
||||
|
||||
Usage:
|
||||
Assume that we have a ``GraphModule``, and we have already done the extractions
|
||||
to the graph to retrieve all information needed, then we could use the following
|
||||
code to find a solution using ``CheckpointSolverChen``:
|
||||
>>> solver = CheckpointSolverChen(gm.graph)
|
||||
>>> chen_graph = solver.solve()
|
||||
>>> gm.graph = chen_graph # set the graph to a new graph
|
||||
|
||||
Args:
|
||||
graph (Graph): The computing graph to be optimized.
|
||||
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
|
||||
num_grids (int, optional): Number of grids to search for b. Defaults to 6.
|
||||
"""
|
||||
super().__init__(graph, 0, 0, True, cnode)
|
||||
self.num_grids = num_grids
|
||||
|
||||
def solve(self) -> Graph:
|
||||
"""Solve the checkpointing problem using Algorithm 3.
|
||||
|
||||
Returns:
|
||||
graph (Graph): The optimized graph, should be a copy of the original graph.
|
||||
"""
|
||||
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
|
||||
ckpt = self.grid_search()
|
||||
for i, seg in enumerate(ckpt):
|
||||
for idx in range(*seg):
|
||||
nodes = self.node_list[idx]
|
||||
for n in nodes:
|
||||
if n.op in checkpointable_op:
|
||||
n.meta['activation_checkpoint'] = i
|
||||
return deepcopy(self.graph)
|
||||
|
||||
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
"""
|
||||
ckpt_intv = []
|
||||
temp = 0
|
||||
x = 0
|
||||
y = 0
|
||||
prev_idx = 2
|
||||
for idx, nodes in enumerate(self.node_list):
|
||||
for n in nodes:
|
||||
n: Node
|
||||
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
|
||||
y = max(y, temp)
|
||||
if temp > b and idx > prev_idx:
|
||||
x += calculate_fwd_in(nodes[0])
|
||||
temp = 0
|
||||
ckpt_intv.append((prev_idx, idx + 1))
|
||||
prev_idx = idx + 1
|
||||
return ckpt_intv, math.floor(math.sqrt(x * y))
|
||||
|
||||
def grid_search(self) -> Set:
|
||||
"""
|
||||
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
|
||||
Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
|
||||
"""
|
||||
_, b_approx = self.run_chen_greedy(0)
|
||||
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
|
||||
b_opt = math.inf
|
||||
for b in range(b_min, b_max, (b_max - b_min) // self.num_grids):
|
||||
ckpt_intv, b_approx = self.run_chen_greedy(b)
|
||||
if b_approx < b_opt:
|
||||
b_opt = b_approx
|
||||
ckpt_opt = ckpt_intv
|
||||
return ckpt_opt
|
|
@ -0,0 +1,197 @@
|
|||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
long* PySequenceToLongArray(PyObject* pylist) {
|
||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||
Py_ssize_t len = PySequence_Size(pylist);
|
||||
long* result = (long*)calloc(len + 1, sizeof(long));
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_GetItem(pylist, i);
|
||||
result[i] = PyLong_AsLong(item);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
result[len] = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
double* PySequenceToDoubleArray(PyObject* pylist) {
|
||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||
Py_ssize_t len = PySequence_Size(pylist);
|
||||
double* result = (double*)calloc(len + 1, sizeof(double));
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_GetItem(pylist, i);
|
||||
result[i] = PyFloat_AsDouble(item);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
result[len] = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
long* getLongArray(PyObject* container, const char* attributeName) {
|
||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
||||
long* result = PySequenceToLongArray(sequence);
|
||||
Py_DECREF(sequence);
|
||||
return result;
|
||||
}
|
||||
|
||||
double* getDoubleArray(PyObject* container, const char* attributeName) {
|
||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
||||
double* result = PySequenceToDoubleArray(sequence);
|
||||
Py_DECREF(sequence);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* computeTable(PyObject* self, PyObject* args) {
|
||||
PyObject* chainParam;
|
||||
int mmax;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL;
|
||||
|
||||
double* ftime = getDoubleArray(chainParam, "ftime");
|
||||
if (!ftime) return NULL;
|
||||
|
||||
double* btime = getDoubleArray(chainParam, "btime");
|
||||
if (!btime) return NULL;
|
||||
|
||||
long* x = getLongArray(chainParam, "x");
|
||||
if (!x) return NULL;
|
||||
|
||||
long* xbar = getLongArray(chainParam, "xbar");
|
||||
if (!xbar) return NULL;
|
||||
|
||||
long* ftmp = getLongArray(chainParam, "ftmp");
|
||||
if (!ftmp) return NULL;
|
||||
|
||||
long* btmp = getLongArray(chainParam, "btmp");
|
||||
if (!btmp) return NULL;
|
||||
|
||||
long chainLength = PyObject_Length(chainParam);
|
||||
if (!chainLength) return NULL;
|
||||
|
||||
#define COST_TABLE(m, i, l) \
|
||||
costTable[(m) * (chainLength + 1) * (chainLength + 1) + \
|
||||
(i) * (chainLength + 1) + (l)]
|
||||
double* costTable = (double*)calloc(
|
||||
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(double));
|
||||
|
||||
#define BACK_PTR(m, i, l) \
|
||||
backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \
|
||||
(i) * (chainLength + 1) + (l)]
|
||||
long* backPtr = (long*)calloc(
|
||||
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chainLength; ++i)
|
||||
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
|
||||
(m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
|
||||
COST_TABLE(m, i, i) = ftime[i] + btime[i];
|
||||
else
|
||||
COST_TABLE(m, i, i) = INFINITY;
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long d = 1; d <= chainLength; ++d) {
|
||||
for (long i = 0; i <= chainLength - d; ++i) {
|
||||
long idx = i + d;
|
||||
long mmin = x[idx + 1] + x[i + 1] + ftmp[i];
|
||||
if (idx > i + 1) {
|
||||
long maxCostFWD = 0;
|
||||
for (long j = i + 1; j < idx; j++) {
|
||||
maxCostFWD = fmaxl(maxCostFWD, x[j] + x[j + 1] + ftmp[j]);
|
||||
}
|
||||
mmin = fmaxl(mmin, x[idx + 1] + maxCostFWD);
|
||||
}
|
||||
if ((m >= mmin)) {
|
||||
long bestLeaf = -1;
|
||||
double sumFw = 0;
|
||||
double bestLeafCost = INFINITY;
|
||||
for (long j = i + 1; j <= idx; ++j) {
|
||||
sumFw += ftime[j - 1];
|
||||
if (m >= x[j]) {
|
||||
double cost = sumFw + COST_TABLE(m - x[j], j, idx) +
|
||||
COST_TABLE(m, i, j - 1);
|
||||
if (cost < bestLeafCost) {
|
||||
bestLeafCost = cost;
|
||||
bestLeaf = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
double chainCost = INFINITY;
|
||||
if (m >= xbar[i + 1])
|
||||
chainCost =
|
||||
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
|
||||
if (bestLeafCost <= chainCost) {
|
||||
COST_TABLE(m, i, idx) = bestLeafCost;
|
||||
BACK_PTR(m, i, idx) = bestLeaf;
|
||||
} else {
|
||||
COST_TABLE(m, i, idx) = chainCost;
|
||||
BACK_PTR(m, i, idx) = -1;
|
||||
}
|
||||
} else
|
||||
COST_TABLE(m, i, idx) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
free(ftime);
|
||||
free(btime);
|
||||
free(x);
|
||||
free(xbar);
|
||||
free(ftmp);
|
||||
free(btmp);
|
||||
|
||||
PyObject* pyCostTable = PyList_New(mmax + 1);
|
||||
PyObject* pyBackPtr = PyList_New(mmax + 1);
|
||||
|
||||
// Convert the result into Python world
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
PyObject* pyCostTable_m = PyList_New(chainLength + 1);
|
||||
PyList_SET_ITEM(pyCostTable, m, pyCostTable_m);
|
||||
PyObject* pyBackPtr_m = PyList_New(chainLength + 1);
|
||||
PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m);
|
||||
for (long i = 0; i <= chainLength; ++i) {
|
||||
PyObject* pyCostTable_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i);
|
||||
PyObject* pyBackPtr_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i);
|
||||
for (long l = i; l <= chainLength; ++l) {
|
||||
PyObject* pyVar_l = PyLong_FromLong(l);
|
||||
PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l));
|
||||
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
|
||||
Py_DECREF(pyCostTable_m_i_l);
|
||||
PyObject* pyBackPtr_m_i_l;
|
||||
if (BACK_PTR(m, i, l) < 0)
|
||||
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
|
||||
else
|
||||
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
|
||||
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
|
||||
Py_DECREF(pyBackPtr_m_i_l);
|
||||
Py_DECREF(pyVar_l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(costTable);
|
||||
free(backPtr);
|
||||
|
||||
PyObject* result = PyTuple_Pack(2, pyCostTable, pyBackPtr);
|
||||
Py_DECREF(pyCostTable);
|
||||
Py_DECREF(pyBackPtr);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyMethodDef rotorMethods[] = {
|
||||
{"compute_table", computeTable, METH_VARARGS,
|
||||
"Compute the optimal table with the rotor algorithm."},
|
||||
{NULL, NULL, 0, NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
static struct PyModuleDef rotorModule = {
|
||||
PyModuleDef_HEAD_INIT, "rotorc", /* name of module */
|
||||
"A simple implementation of dynamic programming algorithm rotor with C in "
|
||||
"https://hal.inria.fr/hal-02352969. Some code are adapted from "
|
||||
"https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be
|
||||
NULL */
|
||||
-1, /* size of per-interpreter state of the module,
|
||||
or -1 if the module keeps state in global variables. */
|
||||
rotorMethods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); }
|
|
@ -0,0 +1,441 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.fx.profiler import (
|
||||
activation_size,
|
||||
calculate_bwd_time,
|
||||
calculate_fwd_out,
|
||||
calculate_fwd_time,
|
||||
calculate_fwd_tmp,
|
||||
)
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .ckpt_solver_base import CheckpointSolverBase
|
||||
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
|
||||
|
||||
__all__ = ['CheckpointSolverRotor']
|
||||
|
||||
|
||||
class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
|
||||
def __init__(self,
|
||||
graph: Graph,
|
||||
free_memory: float = -1,
|
||||
cnode: List[str] = None,
|
||||
memory_slots: int = 500,
|
||||
optim_multiplier: float = 1.0):
|
||||
"""This is the simple implementation of dynamic programming algorithm rotor
|
||||
in https://hal.inria.fr/hal-02352969. Some code is adapted from
|
||||
https://gitlab.inria.fr/hiepacs/rotor.
|
||||
|
||||
Usage:
|
||||
Assume that we have a ``GraphModule``, and we have already done the extractions
|
||||
to the graph to retrieve all information needed, then we could use the following
|
||||
code to find a solution using ``CheckpointSolverRotor``:
|
||||
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
|
||||
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
|
||||
>>> gm.graph = rotor_graph # set the graph to a new graph
|
||||
|
||||
Args:
|
||||
graph (Graph): The computing graph to be optimized.
|
||||
free_memory (float, optional): Memory constraint for the solution, unit is byte.
|
||||
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
|
||||
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
|
||||
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
|
||||
optim_multiplier (float, optional): The multiplier of extra weight storage for the
|
||||
``torch.optim.Optimizer``. Default to 1.0.
|
||||
"""
|
||||
super().__init__(graph, free_memory, True, cnode, optim_multiplier)
|
||||
self.memory_slots = memory_slots
|
||||
|
||||
# construct chain
|
||||
unit = self.free_memory // self.memory_slots
|
||||
self.chain = self._construct_chain(self.graph, self.node_list)
|
||||
self.chain.discretize_all(unit)
|
||||
|
||||
self.cost_table = None
|
||||
self.back_ptr = None
|
||||
self.sequence = None
|
||||
|
||||
def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
|
||||
"""Solve the checkpointing problem using rotor algorithm.
|
||||
|
||||
Args:
|
||||
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
|
||||
verbose (bool, optional): Print verbose information. Defaults to False.
|
||||
|
||||
Returns:
|
||||
graph (Graph): The optimized graph, should be a copy of the original graph.
|
||||
"""
|
||||
chain = self.chain
|
||||
|
||||
# compute cost table
|
||||
if force_python:
|
||||
self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
|
||||
else:
|
||||
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
|
||||
|
||||
if verbose:
|
||||
self.print_chain()
|
||||
|
||||
# backtrack
|
||||
try:
|
||||
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
|
||||
self.back_ptr)
|
||||
self._annotate_from_sequence(self.sequence, self.node_list)
|
||||
except ValueError as e:
|
||||
# use the logger to announce that the solver failed
|
||||
logger = get_dist_logger()
|
||||
logger.warning(f'Checkpoint solver failed: {e}')
|
||||
raise ValueError
|
||||
|
||||
if verbose:
|
||||
self.print_sequence()
|
||||
|
||||
return deepcopy(self.graph)
|
||||
|
||||
def print_chain(self):
|
||||
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
|
||||
for idx in range(len(self.node_list) - 1):
|
||||
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
|
||||
self.chain.btmp[idx])
|
||||
print(f'Chain = {self.chain}')
|
||||
|
||||
def print_sequence(self):
|
||||
print(f'Sequence = {self.sequence}')
|
||||
|
||||
@classmethod
|
||||
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
|
||||
input_tensors = cls._extract_input(graph)
|
||||
ftime, btime, ftmp, btmp = list(), list(), list(), list()
|
||||
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
|
||||
|
||||
for node in node_list:
|
||||
node_info = cls._extract_node_info(node)
|
||||
ftime.append(node_info[0])
|
||||
btime.append(node_info[1])
|
||||
x.append(node_info[2])
|
||||
xbar.append(node_info[3])
|
||||
ftmp.append(node_info[4])
|
||||
btmp.append(node_info[5])
|
||||
|
||||
# currently we view loss backward temp as zero
|
||||
btime.append(0)
|
||||
btmp.append(0)
|
||||
|
||||
return Chain(ftime, btime, x, xbar, ftmp, btmp)
|
||||
|
||||
@classmethod
|
||||
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
|
||||
"""Extract node info from a list of nodes"""
|
||||
xbar = 0
|
||||
ftime = 0
|
||||
btime = 0
|
||||
fwd_mem_peak = 0
|
||||
for n in node:
|
||||
assert isinstance(n, Node), f'{n} is not a Node'
|
||||
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
|
||||
# in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
|
||||
xbar += n.meta['fwd_mem_out']
|
||||
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
|
||||
else:
|
||||
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
||||
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
|
||||
|
||||
# minimum flop count is required
|
||||
ftime += max(calculate_fwd_time(n), 1.0)
|
||||
btime += max(calculate_bwd_time(n), 1.0)
|
||||
|
||||
x = calculate_fwd_out(node[-1])
|
||||
xbar = max(x, xbar)
|
||||
ftmp = fwd_mem_peak - xbar
|
||||
btmp = cls._extract_btmp(node)
|
||||
return ftime, btime, x, xbar, ftmp, btmp
|
||||
|
||||
@staticmethod
|
||||
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
|
||||
"""Extract input tensors from a Graph"""
|
||||
input_tensors = []
|
||||
for node in graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
input_tensors.append(node.meta['fwd_out'])
|
||||
return input_tensors
|
||||
|
||||
@staticmethod
|
||||
def _extract_unused_output(node: Node) -> int:
|
||||
"""Extract unused output from `torch.fx.Node`"""
|
||||
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
|
||||
|
||||
@staticmethod
|
||||
def _extract_btmp(node: List[Node]) -> int:
|
||||
"""Extract btmp from a list of nodes"""
|
||||
|
||||
def _extract_deps_size():
|
||||
deps_size = 0
|
||||
for k, v in deps.items():
|
||||
k: Node
|
||||
if v > 0:
|
||||
deps_size += k.meta['bwd_mem_out']
|
||||
if v == float('-inf'):
|
||||
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
|
||||
|
||||
return deps_size
|
||||
|
||||
btmp = 0
|
||||
deps = {}
|
||||
for n in reversed(node):
|
||||
deps[n] = len(n.all_input_nodes)
|
||||
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
|
||||
for child in n.users:
|
||||
if child in deps:
|
||||
deps[child] -= 1
|
||||
if deps[child] <= 0:
|
||||
deps[child] = float('-inf') # free
|
||||
return btmp
|
||||
|
||||
@staticmethod
|
||||
def _compute_table(chain: Chain, mmax: int) -> Tuple:
|
||||
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
|
||||
|
||||
Args:
|
||||
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
|
||||
mmax (int): Maximum number of memory slots.
|
||||
|
||||
Returns:
|
||||
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
|
||||
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
|
||||
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
|
||||
is a chain checkpoint, or (False, j) if the optimal choice is a leaf checkpoint
|
||||
of length j
|
||||
"""
|
||||
|
||||
ftime = chain.ftime + [0.0]
|
||||
btime = chain.btime
|
||||
x = chain.x + [0]
|
||||
xbar = chain.xbar + [0]
|
||||
ftmp = chain.ftmp + [0]
|
||||
btmp = chain.btmp + [0]
|
||||
|
||||
# Build table
|
||||
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||
# The last dimension is a dict because its indices go from i to l. Renumbering will wait for the C implementation
|
||||
|
||||
# Initialize borders of the tables for lmax-lmin = 0
|
||||
for m in range(mmax + 1):
|
||||
for i in range(len(chain) + 1):
|
||||
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
|
||||
if m >= limit: # Equation (1)
|
||||
cost_table[m][i][i] = ftime[i] + btime[i]
|
||||
else:
|
||||
cost_table[m][i][i] = float("inf")
|
||||
|
||||
# Compute everything
|
||||
for m in range(mmax + 1):
|
||||
for d in range(1, len(chain) + 1):
|
||||
for i in range(len(chain) + 1 - d):
|
||||
idx = i + d
|
||||
mmin = x[idx + 1] + x[i + 1] + ftmp[i]
|
||||
if idx > i + 1:
|
||||
mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
|
||||
if m < mmin:
|
||||
cost_table[m][i][idx] = float("inf")
|
||||
else:
|
||||
leaf_checkpoints = [(j,
|
||||
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
|
||||
for j in range(i + 1, idx + 1)
|
||||
if m >= x[j]]
|
||||
if leaf_checkpoints:
|
||||
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
|
||||
else:
|
||||
best_leaf = None
|
||||
if m >= xbar[i + 1]:
|
||||
chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
|
||||
else:
|
||||
chain_checkpoint = float("inf")
|
||||
if best_leaf and best_leaf[1] <= chain_checkpoint:
|
||||
cost_table[m][i][idx] = best_leaf[1]
|
||||
back_ptr[m][i][idx] = (False, best_leaf[0])
|
||||
else:
|
||||
cost_table[m][i][idx] = chain_checkpoint
|
||||
back_ptr[m][i][idx] = (True,)
|
||||
return cost_table, back_ptr
|
||||
|
||||
@staticmethod
|
||||
def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
|
||||
try:
|
||||
from .rotorc import compute_table
|
||||
|
||||
# build module if module not found
|
||||
except ModuleNotFoundError:
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
logger = get_dist_logger()
|
||||
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
result = subprocess.Popen(
|
||||
[
|
||||
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
|
||||
f"--build-lib={this_dir}"
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
if result.wait() == 0:
|
||||
logger.info("rotorc has been built!", ranks=[0])
|
||||
from .rotorc import compute_table
|
||||
else:
|
||||
logger.warning("rotorc built failed! Using python version!", ranks=[0])
|
||||
return CheckpointSolverRotor._compute_table(chain, mmax)
|
||||
return compute_table(chain, mmax)
|
||||
|
||||
@staticmethod
|
||||
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
|
||||
back_ptr: List[Any]) -> "Sequence":
|
||||
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
|
||||
|
||||
Args:
|
||||
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
|
||||
lhs (int): The left index of the interval to backtrack.
|
||||
rhs (int): The right index of the interval to backtrack.
|
||||
budget (int): The memory budget for processing this interval.
|
||||
cost_table (List[Any]): See ``._compute_table()`` for definitions
|
||||
back_ptr (List[Any]): See ``._compute_table()`` for definitions
|
||||
|
||||
Raises:
|
||||
ValueError: Can not process the chain.
|
||||
|
||||
Returns:
|
||||
sequence (Sequence): The sequence of executing nodes with checkpoints.
|
||||
"""
|
||||
if budget <= 0:
|
||||
raise ValueError(f"Can not process a chain with negative memory {budget}")
|
||||
elif cost_table[budget][lhs][rhs] == float("inf"):
|
||||
raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")
|
||||
|
||||
sequence = Sequence()
|
||||
if rhs == lhs:
|
||||
if lhs == len(chain):
|
||||
sequence += [Loss()]
|
||||
else:
|
||||
sequence += [ForwardEnable(lhs), Backward(lhs)]
|
||||
return sequence
|
||||
|
||||
if back_ptr[budget][lhs][rhs][0]:
|
||||
sequence += [
|
||||
ForwardEnable(lhs),
|
||||
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
|
||||
back_ptr),
|
||||
Backward(lhs),
|
||||
]
|
||||
else:
|
||||
best_leaf = back_ptr[budget][lhs][rhs][1]
|
||||
sequence += [ForwardCheck(lhs)]
|
||||
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
|
||||
sequence += [
|
||||
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
|
||||
back_ptr),
|
||||
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
|
||||
]
|
||||
return sequence
|
||||
|
||||
@staticmethod
|
||||
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
||||
"""Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.
|
||||
|
||||
Args:
|
||||
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
|
||||
node_list (List[List[Node]]): The list of nodes to annotate.
|
||||
"""
|
||||
op_list = sequence.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
fwd_list = op_list[:op_list.index(loss_op)]
|
||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
||||
ckpt_idx = 0
|
||||
in_ckpt = False
|
||||
ckpt_region = []
|
||||
|
||||
# forward annotation
|
||||
for idx, op in enumerate(fwd_list, 0):
|
||||
if in_ckpt:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(idx)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
in_ckpt = False
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'] = [ckpt_idx]
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'] = [ckpt_idx]
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [idx]
|
||||
|
||||
else:
|
||||
if isinstance(op, ForwardCheck):
|
||||
in_ckpt = True
|
||||
ckpt_region.append(idx)
|
||||
|
||||
# annotate the backward if there is any nested activation checkpoint
|
||||
in_recompute = False
|
||||
for op in bwd_list:
|
||||
if in_recompute:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
|
||||
elif isinstance(op, Backward):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
|
||||
in_recompute = False
|
||||
|
||||
else:
|
||||
if not isinstance(op, Backward):
|
||||
in_recompute = True
|
||||
ckpt_idx = 0
|
||||
ckpt_region = []
|
||||
if isinstance(op, ForwardCheck):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
# postprocess, make sure every activation checkpoint label in the
|
||||
# same activation checkpoint region (level = 0) has the same length
|
||||
op_list = []
|
||||
for node in node_list:
|
||||
op_list += node
|
||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
||||
for (start_idx, end_idx) in ckpt_regions:
|
||||
nested_length = max(
|
||||
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
|
||||
len(op_list[idx].meta['activation_checkpoint']))
|
|
@ -0,0 +1,184 @@
|
|||
import math
|
||||
from abc import ABC
|
||||
from typing import Any, Iterable, List
|
||||
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
class Chain:
|
||||
|
||||
def __init__(self,
|
||||
ftime: List[float],
|
||||
btime: List[float],
|
||||
x: List[int],
|
||||
xbar: List[int],
|
||||
ftmp: List[int],
|
||||
btmp: List[int],
|
||||
check_consistency: bool = True):
|
||||
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
|
||||
See paper https://hal.inria.fr/hal-02352969 for details.
|
||||
|
||||
Args:
|
||||
ftime (List[float]): The forward time of each node.
|
||||
btime (List[float]): The backward time of each node.
|
||||
x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
|
||||
xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
|
||||
ftmp (List[int]): The temporary forward memory of each node.
|
||||
btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget.
|
||||
check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True.
|
||||
"""
|
||||
self.ftime = ftime
|
||||
self.btime = btime
|
||||
self.x = x
|
||||
self.xbar = xbar
|
||||
self.ftmp = ftmp
|
||||
self.btmp = btmp
|
||||
if check_consistency and not self.check_lengths():
|
||||
raise AttributeError("In Chain, input lists do not have consistent lengths")
|
||||
|
||||
def check_lengths(self):
|
||||
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
|
||||
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
|
||||
and (len(self.xbar) == len(self) + 1))
|
||||
|
||||
def __repr__(self):
|
||||
chain_list = []
|
||||
for i in range(len(self)):
|
||||
chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
|
||||
i = len(self)
|
||||
chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
|
||||
return chain_list.__repr__()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ftime)
|
||||
|
||||
def discretize_all(self, unit: int):
|
||||
"""Discretize the chain into a list of chains according to unit size."""
|
||||
discretizer = lambda val: math.ceil(val / unit)
|
||||
self.x = tree_map(discretizer, self.x)
|
||||
self.xbar = tree_map(discretizer, self.xbar)
|
||||
self.ftmp = tree_map(discretizer, self.ftmp)
|
||||
self.btmp = tree_map(discretizer, self.btmp)
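As a small illustration of the length rules enforced by ``check_lengths`` and of ``discretize_all``, a hypothetical two-node chain (all numbers are made up) could look like this:

```python
# hypothetical example: len(ftime) == len(ftmp) == 2, the remaining lists have length 3
chain = Chain(ftime=[1.0, 2.0],
              btime=[2.0, 3.0, 0.0],
              x=[16, 8, 4],
              xbar=[16, 12, 4],
              ftmp=[0, 0],
              btmp=[0, 0, 0])
chain.discretize_all(unit=4)    # each memory entry becomes ceil(value / 4)
```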
|
||||
|
||||
|
||||
class Operation(ABC):
|
||||
name = "Op"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.name}_{self.index}"
|
||||
|
||||
def shift(self, value):
|
||||
if type(self.index) is tuple:
|
||||
self.index = tuple(x + value for x in self.index)
|
||||
else:
|
||||
self.index += value
|
||||
|
||||
|
||||
class Forward(Operation):
|
||||
name = "F"
|
||||
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return chain.ftime[self.index]
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
class ForwardEnable(Forward):
|
||||
name = "Fe"
|
||||
|
||||
|
||||
class ForwardNograd(Forward):
|
||||
name = "Fn"
|
||||
|
||||
|
||||
class ForwardCheck(Forward):
|
||||
name = "CF"
|
||||
|
||||
|
||||
class Forwards(Operation):
|
||||
|
||||
def __init__(self, start, end):
|
||||
self.index = (start, end)
|
||||
|
||||
def __repr__(self):
|
||||
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
|
||||
else:
|
||||
return (self.index[1] - self.index[0] + 1)
|
||||
|
||||
|
||||
def isForward(op):
|
||||
return type(op) is Forward or type(op) is Forwards
|
||||
|
||||
|
||||
class Backward(Operation):
|
||||
name = "B"
|
||||
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return chain.btime[self.index]
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
class Loss(Operation):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "L"
|
||||
|
||||
def cost(self, chain):
|
||||
return 0
|
||||
|
||||
|
||||
class MemoryAccess(Operation):
|
||||
name = "MA"
|
||||
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
return 0
|
||||
|
||||
|
||||
class WriteMemory(MemoryAccess):
|
||||
name = "WM"
|
||||
|
||||
|
||||
class ReadMemory(MemoryAccess):
|
||||
name = "RM"
|
||||
|
||||
|
||||
class DiscardMemory(MemoryAccess):
|
||||
name = "DM"
|
||||
|
||||
|
||||
class Sequence(list):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.list_operations())
|
||||
|
||||
def list_operations(self):
|
||||
op_list = []
|
||||
for x in self:
|
||||
if isinstance(x, Operation):
|
||||
op_list.append(x)
|
||||
else:
|
||||
assert isinstance(x, Sequence)
|
||||
op_list += x.list_operations()
|
||||
return op_list
|
|
@ -0,0 +1,3 @@
|
|||
from .meta_registry import *
|
||||
from .metainfo import *
|
||||
from .registry import meta_register
|
|
@ -0,0 +1,15 @@
|
|||
import operator
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..tensor_shard.constants import *
|
||||
|
||||
# list of inplace module
|
||||
INPLACE_MODULE = [nn.ReLU]
|
||||
|
||||
# list of inplace operations
|
||||
INPLACE_OPS = [torch.flatten]
|
||||
|
||||
# list of operations that do not save forward activations
|
||||
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
|
|
@ -0,0 +1,6 @@
|
|||
from .activation import *
|
||||
from .binary_elementwise_ops import *
|
||||
from .conv import *
|
||||
from .linear import *
|
||||
from .norm import *
|
||||
from .pooling import *
|
|
@ -0,0 +1,74 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ["relu_meta_info"]
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.ReLU)
|
||||
def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.ReLU metainfo generator
|
||||
The aten graph of torch.nn.ReLU is
|
||||
graph():
|
||||
%input_2 : [#users=1] = placeholder[target=placeholder](default=)
|
||||
%relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {})
|
||||
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {})
|
||||
%threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
is_inplace = kwargs.get("inplace", False)
|
||||
|
||||
# construct input args for forward
|
||||
fwd_in_args = [input_tensor]
|
||||
|
||||
# construct input args for backward
|
||||
bwd_in_args = [output_tensor]
|
||||
|
||||
# calculate cost
|
||||
# the fwd op with compute cost is relu.default
|
||||
# the bwd op with compute cost is threshold_backward
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,))
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,))
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# NOTE: the inplace ReLU doesn't have a forward memory cost
|
||||
# NOTE: currently the SPMD solver always assumes that a new tensor will be created in the forward pass
|
||||
fwd_memory_cost = MemoryCost(
|
||||
activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]),
|
||||
parameter=0,
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0)
|
||||
|
||||
# total cost is the sum of forward and backward cost
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
# NOTE: It might seem a little weird here; we just want to align it with the older version
|
||||
# of MetaInfoProp. In the future we might modify this part to make it clearer.
|
||||
fwd_in = []
|
||||
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
|
@ -0,0 +1,66 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
|
||||
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ['binary_elementwise_meta_info']
|
||||
|
||||
|
||||
@meta_register.register(BCAST_FUNC_OP)
|
||||
def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""Meta information generator for binary elementwise operations
|
||||
NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
|
||||
don't need those tensors for back propagation. For example, if two tensors are sent to `torch.add`,
|
||||
they will be discarded right after the add operation is done. We create a simple API in the `MetaInfo` class to identify
|
||||
this behavior, which is critical for better memory estimation.
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
|
||||
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
|
||||
|
||||
# construct forward args for flop mapping
|
||||
fwd_in_args = [opdata.data for opdata in input_op_data]
|
||||
fwd_out_args = [output_op_data.data]
|
||||
|
||||
# calculate cost
|
||||
|
||||
# calculate compute cost
|
||||
# NOTE: we set bwd_compute_cost to two times fwd_compute_cost in this case
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
|
||||
bwd_compute_cost = fwd_compute_cost * 2
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
|
||||
fwd_mem_cost = MemoryCost(
|
||||
activation=activation_size(output_op_data.data),
|
||||
parameter=param_mem_cost,
|
||||
)
|
||||
bwd_mem_cost = MemoryCost(
|
||||
activation=activation_size(fwd_in_args),
|
||||
parameter=param_mem_cost,
|
||||
)
|
||||
|
||||
# total cost
|
||||
total_mem_cost = MemoryCost(
|
||||
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
|
||||
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
|
||||
)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = []
|
||||
fwd_buffer = []
|
||||
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
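A back-of-the-envelope check of the conventions used above, with illustrative numbers; the exact FLOP convention of flop_mapping may differ, this only shows the quantities being recorded.

import torch

x = torch.zeros(1024, 1024, device='meta')           # one operand of z = x + y in fp32
fwd_flop = x.numel()                                  # roughly one add per output element
bwd_flop = 2 * fwd_flop                               # convention above: bwd cost is twice the fwd cost
fwd_activation_bytes = x.numel() * x.element_size()   # only the output is kept as forward activation
print(fwd_flop, bwd_flop, fwd_activation_bytes)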
|
|
@ -0,0 +1,137 @@
|
|||
from typing import Callable, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ['convnd_meta_info']
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.Conv1d)
|
||||
@meta_register.register(torch.nn.Conv2d)
|
||||
@meta_register.register(torch.nn.Conv3d)
|
||||
@meta_register.register(torch.nn.functional.conv1d)
|
||||
@meta_register.register(torch.nn.functional.conv2d)
|
||||
@meta_register.register(torch.nn.functional.conv3d)
|
||||
def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
|
||||
The aten graph of torch.nn.Convnd with bias is
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None), kwargs = {})
|
||||
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%convolution_backward_default : [#users=3] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None, [None, None, None]), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
|
||||
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
|
||||
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
|
||||
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
|
||||
|
||||
The aten graph of torch.nn.Convnd without bias is
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None], [None, None], [None, None], None, [None, None], None), kwargs = {})
|
||||
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%convolution_backward_default : [#users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None], [None, None], [None, None], None, [None, None], None, [None, None, None]), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
|
||||
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
has_bias: bool = False
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
if len(args) == 4:
|
||||
weight_tensors = [args[1].data, args[3].data]
|
||||
else:
|
||||
weight_tensors = [args[1].data]
|
||||
|
||||
# check if conv has bias
|
||||
if len(weight_tensors) > 1:
|
||||
has_bias = True
|
||||
# bias tensor's shape only has one dimension
|
||||
if len(weight_tensors[0].shape) == 1:
|
||||
bias_tensor, weight_tensor = weight_tensors
|
||||
else:
|
||||
weight_tensor, bias_tensor = weight_tensors
|
||||
|
||||
else:
|
||||
weight_tensor = weight_tensors[0]
|
||||
|
||||
# construct input args for forward
|
||||
fwd_args = [None] * 9
|
||||
|
||||
# weight and input
|
||||
fwd_args[0] = input_tensor
|
||||
fwd_args[1] = weight_tensor
|
||||
fwd_args[2] = bias_tensor if has_bias else None
|
||||
|
||||
# transpose indicator should be set to False
|
||||
fwd_args[6] = False
|
||||
|
||||
# construct input args for backward
|
||||
bwd_args = [None] * 11
|
||||
|
||||
# weight and input
|
||||
bwd_args[0] = output_tensor
|
||||
bwd_args[1] = input_tensor
|
||||
bwd_args[2] = weight_tensor
|
||||
bwd_args[-1] = [True, True, True] if has_bias else [True, True, False]
|
||||
|
||||
# calculate cost
|
||||
# the fwd op with compute cost is convolution.default
|
||||
# the bwd op with compute cost is convolution_backward.default
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
|
||||
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# TODO: use profiler to check conv temp memory
|
||||
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
|
||||
fwd_memory_cost = MemoryCost(
|
||||
activation=activation_size([input_tensor, output_tensor]),
|
||||
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
bwd_memory_cost = MemoryCost(
|
||||
activation=activation_size([input_tensor, weight_tensor, bias_tensor])
|
||||
if has_bias else activation_size([input_tensor, weight_tensor]),
|
||||
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# total cost is the sum of forward and backward cost
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
|
||||
fwd_buffer = []
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
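A rough, textbook-style estimate of the forward cost recorded for a Conv2d layer, shown only as a sanity check; flop_mapping may follow a different exact counting convention, and all shapes below are illustrative.

# batch size, input channels and output spatial size (illustrative)
N, C_in, H_out, W_out = 8, 3, 112, 112
# output channels and kernel size of an illustrative 7x7 convolution
C_out, k_h, k_w = 64, 7, 7
fwd_macs = N * C_out * H_out * W_out * C_in * k_h * k_w   # one MAC per kernel tap per output element
print(fwd_macs)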
|
|
@ -0,0 +1,172 @@
|
|||
from typing import Callable, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ['linear_meta_info']
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.functional.linear)
|
||||
@meta_register.register(torch.nn.Linear)
|
||||
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.Linear & torch.nn.functional.linear meta info generator
|
||||
NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator,
|
||||
but we will hold the bias mechanism in the linear metainfo generator for future use.
|
||||
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
|
||||
%zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
|
||||
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
|
||||
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
|
||||
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
|
||||
%sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {})
|
||||
%view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {})
|
||||
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
|
||||
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
|
||||
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
|
||||
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
|
||||
|
||||
The one without bias is
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {})
|
||||
%zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
|
||||
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
|
||||
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
|
||||
%mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
|
||||
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
|
||||
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
has_bias: bool = False
|
||||
|
||||
input_tensor = args[0].data
|
||||
output_tensor = args[2].data
|
||||
if len(args) == 4:
|
||||
weight_tensors = [args[1].data, args[3].data]
|
||||
else:
|
||||
weight_tensors = [args[1].data]
|
||||
|
||||
# process the dimension of input and output
|
||||
if len(input_tensor.shape) > 2:
|
||||
input_tensor: torch.Tensor
|
||||
input_tensor = input_tensor.view(-1, input_tensor.shape[-1])
|
||||
|
||||
if len(output_tensor.shape) > 2:
|
||||
output_tensor: torch.Tensor
|
||||
output_tensor = output_tensor.view(-1, output_tensor.shape[-1])
|
||||
|
||||
if len(weight_tensors) > 1:
|
||||
has_bias = True
|
||||
if len(weight_tensors[0].shape) == 2:
|
||||
weight_tensor, bias_tensor = weight_tensors
|
||||
else:
|
||||
bias_tensor, weight_tensor = weight_tensors
|
||||
else:
|
||||
weight_tensor = weight_tensors[0]
|
||||
|
||||
if has_bias:
|
||||
# calculate cost with bias
|
||||
# the fwd op with compute cost is addmm
|
||||
# the bwd op with compute cost is mm * 2 and sum.dim_IntList
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
|
||||
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
|
||||
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
|
||||
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
|
||||
bwd=bwd_compute_cost,
|
||||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# NOTE: Linear doesn't have buffer or temp memory in the forward and backward phases
|
||||
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
|
||||
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
|
||||
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
|
||||
parameter=activation_size([weight_tensor, bias_tensor]),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
|
||||
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
|
||||
parameter=activation_size([weight_tensor, bias_tensor]),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# total cost is the sum of the forward and backward costs
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
else:
|
||||
# calculate cost without bias
|
||||
# the fwd op with compute cost is mm
|
||||
# the bwd op with compute cost is mm * 2
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
|
||||
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
|
||||
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
|
||||
bwd=bwd_compute_cost,
|
||||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# NOTE: Linear doesn't have buffer or temp memory in the forward and backward phases
|
||||
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
|
||||
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
|
||||
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
|
||||
parameter=activation_size(weight_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
|
||||
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
|
||||
parameter=activation_size(weight_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# total cost is the sum of the forward and backward costs
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
|
||||
fwd_buffer = []
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
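The flop_mapping calls above take a list of input tensors and a tuple of output tensors. A minimal standalone sketch of that call for the unbiased matmul case looks like this; the shapes are illustrative and the sketch assumes the same environment as this patch.

import torch

from colossalai.fx.profiler.opcount import flop_mapping

inp = torch.zeros(128, 512, device='meta')       # (batch, in_features)
weight = torch.zeros(256, 512, device='meta')    # (out_features, in_features)
out = torch.zeros(128, 256, device='meta')       # (batch, out_features)

# forward cost of out = inp @ weight.T, mirroring the call in linear_meta_info
fwd_cost = flop_mapping[torch.ops.aten.mm.default](
    [inp, torch.transpose(weight, 0, 1)], (out,))
print(fwd_cost)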
|
|
@ -0,0 +1,103 @@
|
|||
from typing import Callable, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ['batchnormnd_meta_info']
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.BatchNorm1d)
|
||||
@meta_register.register(torch.nn.BatchNorm2d)
|
||||
@meta_register.register(torch.nn.BatchNorm3d)
|
||||
def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""BatchNorm1d, BatchNorm2d, BatchNorm3d, meta info generator
|
||||
The aten graph of BatchNorm2d is like
|
||||
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%cudnn_batch_norm_default : [#users=4] = call_function[target=torch.ops.aten.cudnn_batch_norm.default](args = (%input_2, None, None, None, None, None, None, None), kwargs = {})
|
||||
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%cudnn_batch_norm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
|
||||
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
|
||||
%cudnn_batch_norm_backward_default : [#users=3] = call_function[target=torch.ops.aten.cudnn_batch_norm_backward.default](args = (%detach_default, %zeros_like_default, None, None, None, %detach_default_1, %detach_default_2, None, %detach_default_3), kwargs = {})
|
||||
%detach_default_4 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
|
||||
%detach_default_5 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_4,), kwargs = {})
|
||||
%detach_default_6 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
|
||||
%detach_default_7 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_6,), kwargs = {})
|
||||
%detach_default_8 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
|
||||
%detach_default_9 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_8,), kwargs = {})
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
|
||||
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
|
||||
mean_tensor = next(filter(lambda x: x.name == "running_mean", args)).data
|
||||
var_tensor = next(filter(lambda x: x.name == "running_var", args)).data
|
||||
num_batch = next(filter(lambda x: x.name == "num_batches_tracked", args)).data
|
||||
|
||||
# construct fwd args
|
||||
# the fwd inputs are input, weight, bias, running_mean, running_var and some other args
|
||||
# indicating the status of the module
|
||||
# the fwd outputs are output, saved mean, saved inv std and num batches tracked
|
||||
fwd_in_args = [input_tensor, weight_tensor, bias_tensor, mean_tensor, var_tensor, True, 0.1, 1e-5]
|
||||
fwd_out_args = [output_tensor, mean_tensor, var_tensor, num_batch]
|
||||
|
||||
# construct bwd args
|
||||
# the bwd inputs are upstream grad, input, weight, running_mean, running_var, saved mean,
|
||||
# saved inv std and some other args indicating the status of the module
|
||||
# the bwd outputs are input grad, weight grad and bias grad
|
||||
bwd_in_args = [
|
||||
output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
|
||||
]
|
||||
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
|
||||
|
||||
# calculate cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_in_args, fwd_out_args)
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_in_args, bwd_out_args)
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# the fwd activation cost is output plus saved mean and saved inv std
|
||||
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
|
||||
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]),
|
||||
parameter=activation_size([weight_tensor, bias_tensor]),
|
||||
temp=0,
|
||||
buffer=activation_size([mean_tensor, var_tensor]))
|
||||
|
||||
# the bwd memory cost is quite tricky here: BatchNorm discards the saved mean
# and saved inv std during the backward phase
|
||||
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]),
|
||||
parameter=activation_size([weight_tensor, bias_tensor]),
|
||||
temp=activation_size([mean_tensor, var_tensor]),
|
||||
buffer=activation_size([mean_tensor, var_tensor]))
|
||||
|
||||
# total cost is the sum of forward and backward cost
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
|
||||
fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
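The buffer field above tracks running_mean and running_var. A quick, hypothetical check of their byte size with activation_size, assuming it returns a byte count for a tensor or a list of tensors as it is used throughout this patch:

import torch

from colossalai.fx.profiler.memory_utils import activation_size

running_mean = torch.zeros(64, device='meta')   # one entry per channel (illustrative channel count)
running_var = torch.zeros(64, device='meta')
buffer_bytes = activation_size([running_mean, running_var])   # expected 2 * 64 * 4 bytes for fp32
print(buffer_bytes)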
|
|
@ -0,0 +1,134 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ["avgpool_meta_info", "maxpool_meta_info"]
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.AdaptiveAvgPool1d)
|
||||
@meta_register.register(torch.nn.AdaptiveAvgPool2d)
|
||||
@meta_register.register(torch.nn.AdaptiveAvgPool3d)
|
||||
@meta_register.register(torch.flatten)
|
||||
def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""Meta info for AdaptiveAvgPool
|
||||
The aten graph of AdaptiveAvgPool is
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%_adaptive_avg_pool2d_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d.default](args = (%input_2, [None, None]), kwargs = {})
|
||||
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%_adaptive_avg_pool2d_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%_adaptive_avg_pool2d_backward_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d_backward.default](args = (%zeros_like_default, %detach_default), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_adaptive_avg_pool2d_backward_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
is_inplace = kwargs.get("inplace", False)
|
||||
|
||||
# construct forward args for flop mapping
|
||||
fwd_in_args = [input_tensor]
|
||||
fwd_out_args = [output_tensor]
|
||||
|
||||
# construct backward args for flop mapping
|
||||
bwd_in_args = [output_tensor]
|
||||
bwd_out_args = [input_tensor]
|
||||
|
||||
# calculate cost
|
||||
# the fwd op with compute cost is _adaptive_avg_pool2d.default
|
||||
# the bwd op with compute cost is _adaptive_avg_pool2d_backward.default
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default](bwd_in_args, bwd_out_args)
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
|
||||
bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
|
||||
|
||||
# total cost
|
||||
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
|
||||
|
||||
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = []
|
||||
fwd_buffer = []
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.MaxPool1d)
|
||||
@meta_register.register(torch.nn.MaxPool2d)
|
||||
@meta_register.register(torch.nn.MaxPool3d)
|
||||
def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""Meta info for MaxPool
|
||||
The aten graph of MaxPool is
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%max_pool2d_with_indices_default : [#users=2] = call_function[target=torch.ops.aten.max_pool2d_with_indices.default](args = (%input_2, [None, None], [None, None]), kwargs = {})
|
||||
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%max_pool2d_with_indices_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_default,), kwargs = {})
|
||||
%max_pool2d_with_indices_backward_default : [#users=1] = call_function[target=torch.ops.aten.max_pool2d_with_indices_backward.default](args = (%zeros_like_default, %detach_default, [None, None], [None, None], [None, None], [None, None], None, %detach_default_1), kwargs = {})
|
||||
%detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_backward_default,), kwargs = {})
|
||||
%detach_default_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_2,), kwargs = {})
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
|
||||
# construct forward args for flop mapping
|
||||
fwd_in_args = [input_tensor]
|
||||
fwd_out_args = [output_tensor]
|
||||
|
||||
# construct backward args for flop mapping
|
||||
bwd_in_args = [output_tensor]
|
||||
bwd_out_args = [input_tensor]
|
||||
|
||||
# construct index matrix
|
||||
index_matrix = torch.zeros_like(output_tensor, device="meta", dtype=torch.int64)
|
||||
|
||||
# calculate cost
|
||||
# the fwd op with compute cost is max_pool2d_with_indices.default
|
||||
# the bwd op with compute cost is max_pool2d_with_indices_backward.default
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default](fwd_in_args, fwd_out_args)
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default](bwd_in_args, bwd_out_args)
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# NOTE: the index matrix will be discarded in backward phase
|
||||
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix]))
|
||||
|
||||
# temp memory for backward is the index matrix to be discarded
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix),
|
||||
temp=activation_size(index_matrix))
|
||||
|
||||
# total cost
|
||||
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
|
||||
|
||||
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
|
||||
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
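The index matrix above matters because max pooling keeps int64 indices alongside its output; the sketch below, with illustrative shapes, shows why its footprint is tracked separately from the fp32 activation.

import torch

out = torch.zeros(8, 64, 56, 56, device='meta')                         # pooled output, fp32
index_matrix = torch.zeros_like(out, device='meta', dtype=torch.int64)  # same shape, int64 indices
print(out.numel() * out.element_size())                    # 4 bytes per output element
print(index_matrix.numel() * index_matrix.element_size())  # 8 bytes per index, twice the output size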
|
|
@ -0,0 +1,117 @@
|
|||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
|
||||
from .registry import meta_register
|
||||
|
||||
__all__ = ['MetaInfo']
|
||||
|
||||
|
||||
class MetaInfo:
|
||||
"""MetaInfo class
|
||||
This class is used to store meta info based on sharding strategy and the given
|
||||
target function.
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
|
||||
# compute cost of forward and backward computation
|
||||
self.compute_cost: TrainCycleItem
|
||||
|
||||
# compute memory cost of forward and backward phase
|
||||
self.memory_cost: TrainCycleItem
|
||||
|
||||
# list of input tensors
|
||||
self.fwd_in: List[torch.Tensor]
|
||||
|
||||
# list of buffer tensors
|
||||
self.fwd_buffer: List[torch.Tensor]
|
||||
|
||||
# list of output tensors
|
||||
self.fwd_out: List[torch.Tensor]
|
||||
|
||||
# sharding strategy
|
||||
self._strategy = strategy
|
||||
|
||||
# target function
|
||||
self._target = target
|
||||
|
||||
# compute metainfo if possible
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_metainfo()
|
||||
|
||||
@property
|
||||
def strategy(self) -> ShardingStrategy:
|
||||
return self._strategy
|
||||
|
||||
@property
|
||||
def target(self) -> Callable:
|
||||
return self._target
|
||||
|
||||
@strategy.setter
|
||||
def strategy(self, strategy: ShardingStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_metainfo()
|
||||
|
||||
@target.setter
|
||||
def target(self, target: Callable) -> None:
|
||||
self._target = target
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_metainfo()
|
||||
|
||||
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
"""
|
||||
Compute sharded opdata based on the given data and sharding spec.
|
||||
"""
|
||||
return OperationData(name=operation_data.name,
|
||||
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
|
||||
type=operation_data.type,
|
||||
logical_shape=operation_data.logical_shape)
|
||||
|
||||
def compute_metainfo(self):
|
||||
"""
|
||||
Compute meta info based on sharding strategy and the given target function.
|
||||
"""
|
||||
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
|
||||
f"Meta info for {self._target} is not registered."
|
||||
if meta_register.has(self._target.__class__):
|
||||
# module
|
||||
meta_func = meta_register.get(self._target.__class__)
|
||||
|
||||
# check whether the target is in the list of ops whose activation does not need to be saved
|
||||
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
|
||||
else:
|
||||
# function
|
||||
meta_func = meta_register.get(self._target)
|
||||
|
||||
# check whether the target is in the list of ops whose activation does not need to be saved
save_fwd_in = self._target not in NO_SAVE_ACTIVATION
|
||||
|
||||
# construct args for meta_func
|
||||
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
|
||||
|
||||
# construct kwargs
|
||||
if self.target in INPLACE_MODULE:
|
||||
kwargs = {'inplace': self.target.inplace}
|
||||
elif self.target in INPLACE_OPS:
|
||||
kwargs = {'inplace': True}
|
||||
else:
|
||||
kwargs = {'inplace': False}
|
||||
|
||||
# compute metainfo with meta_func
|
||||
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
|
||||
|
||||
# process corner case for NO_SAVE_ACTIVATION
|
||||
if not save_fwd_in:
|
||||
self.fwd_in = []
|
|
@ -0,0 +1,32 @@
|
|||
__all__ = ['Registry']
|
||||
|
||||
|
||||
class Registry:
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.store = {}
|
||||
|
||||
def register(self, source):
|
||||
|
||||
def wrapper(func):
|
||||
if isinstance(source, (list, tuple)):
|
||||
# support registering a list of items for this func
|
||||
for element in source:
|
||||
self.store[element] = func
|
||||
else:
|
||||
self.store[source] = func
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def get(self, source):
|
||||
assert source in self.store, f'{source} not found in the {self.name} registry'
|
||||
target = self.store[source]
|
||||
return target
|
||||
|
||||
def has(self, source):
|
||||
return source in self.store
|
||||
|
||||
|
||||
meta_register = Registry('meta')
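A minimal usage sketch of the registry pattern above; the names are illustrative, and it assumes the Registry class defined in this file is in scope.

demo_register = Registry('demo')


@demo_register.register([int, float])    # one handler may be registered for several keys at once
def numeric_handler(x):
    return x * 2


@demo_register.register(str)
def string_handler(x):
    return x.upper()


assert demo_register.has(int)
assert demo_register.get(str)('abc') == 'ABC'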
|
|
@ -0,0 +1,113 @@
|
|||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
from colossalai.tensor.comm_spec import CommSpec
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
|
||||
def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
|
||||
target_sharding_spec: ShardingSpec) -> MetaInfo:
|
||||
# get comm_action_sequence and total_cost from shape_consistency_manager
|
||||
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
|
||||
origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
meta_info = MetaInfo()
|
||||
# NOTE: the cost in shape_consistency_manager.mem_cost is counted in number of elements (numel)
|
||||
# get mem cost for MetaInfo
|
||||
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
|
||||
# pick an input node that has _meta_data and get its element size in bytes
|
||||
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
|
||||
element_length = input_node._meta_data.element_size()
|
||||
|
||||
mem_cost.fwd.activation *= element_length
|
||||
mem_cost.fwd.temp *= element_length
|
||||
mem_cost.bwd.activation *= element_length
|
||||
mem_cost.bwd.temp *= element_length
|
||||
mem_cost.total.activation *= element_length
|
||||
|
||||
meta_info.memory_cost = mem_cost
|
||||
|
||||
# get computation cost for MetaInfo
|
||||
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
|
||||
total_cost['backward'] * element_length,
|
||||
total_cost['total'] * element_length)
|
||||
|
||||
# get tensor shape for MetaInfo
|
||||
origin_sharding_spec: ShardingSpec
|
||||
target_sharding_spec: ShardingSpec
|
||||
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
|
||||
output_shape = target_sharding_spec.get_sharded_shape_per_device()
|
||||
|
||||
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
|
||||
meta_info.fwd_buffer = []
|
||||
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
|
||||
|
||||
return meta_info
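The scaling above converts the manager's element counts into bytes via element_size(); a standalone illustration of that conversion, with dtype and shape chosen only for the example:

import torch

t = torch.zeros(1024, 1024, dtype=torch.float16, device='meta')
numel_cost = t.numel()                        # cost as reported by the manager, in elements
byte_cost = numel_cost * t.element_size()     # 2 bytes per fp16 element
print(byte_cost)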
|
||||
|
||||
|
||||
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
|
||||
"""
|
||||
This method is used to construct `MetaInfo` for the shape consistency node
|
||||
"""
|
||||
|
||||
# extract node index and user node index
|
||||
args = node.args
|
||||
node_index, user_node_index = args[3], args[4]
|
||||
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
|
||||
user_node_index]
|
||||
|
||||
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
|
||||
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
|
||||
# extract node_index and op_data_name
|
||||
node_index, op_data_name = node.args[2], node.args[3]
|
||||
|
||||
comm_action = comm_actions_dict[node_index][op_data_name]
|
||||
if isinstance(comm_action.comm_spec, CommSpec):
|
||||
# this case is for all_reduce; there is no memory cost
|
||||
meta_info = MetaInfo()
|
||||
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
|
||||
output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
|
||||
element_length = output_node._meta_data.element_size()
|
||||
|
||||
total_cost = comm_action.comm_spec.get_comm_cost()
|
||||
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
|
||||
total_cost['backward'] * element_length,
|
||||
total_cost['total'] * element_length)
|
||||
|
||||
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
|
||||
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
|
||||
meta_info.fwd_buffer = []
|
||||
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
|
||||
else:
|
||||
# this case will be handled by shape consistency manager
|
||||
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
|
||||
'tgt_spec']
|
||||
meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
return meta_info
|
||||
|
||||
|
||||
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
|
||||
comm_actions_dict: Dict) -> GraphModule:
|
||||
"""
|
||||
This method manages all the meta info of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph.
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if node.target == runtime_apply:
|
||||
setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
|
||||
elif node.target == runtime_comm_spec_apply:
|
||||
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
|
||||
else:
|
||||
pass
|
||||
return gm
|
|
@ -0,0 +1,8 @@
|
|||
import torch
|
||||
|
||||
OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
|
||||
|
||||
OUTPUT_SAVED_MOD = [
|
||||
torch.nn.ReLU,
|
||||
torch.nn.Softmax,
|
||||
]
|
|
@ -0,0 +1,165 @@
|
|||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import GraphInfo
|
||||
|
||||
|
||||
def _normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class MetaInfoProp:
|
||||
|
||||
def __init__(self, module: GraphModule) -> None:
|
||||
self.module = module
|
||||
self.func_dict = {
|
||||
'placeholder': self.placeholder_handler,
|
||||
'get_attr': self.get_attr_handler,
|
||||
'output': self.output_handler,
|
||||
'call_function': self.node_handler,
|
||||
'call_module': self.node_handler,
|
||||
'call_method': self.node_handler,
|
||||
}
|
||||
|
||||
def _set_data_ptr(self, x):
|
||||
"""
|
||||
Assign a uuid-based data_ptr to the tensor if it does not already have one
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
if not x.data_ptr():
|
||||
data_ptr = uuid.uuid4()
|
||||
x.data_ptr = lambda: data_ptr
|
||||
|
||||
def _is_inplace(self, node: Node):
|
||||
"""
|
||||
Check if the node is inplace operation.
|
||||
"""
|
||||
if node.op == 'call_module':
|
||||
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
|
||||
elif node.op == "call_function":
|
||||
return node.target in OUTPUT_SAVED_OPS
|
||||
return False
|
||||
|
||||
def run(self) -> GraphModule:
|
||||
"""
|
||||
Run the meta information propagation pass on the module.
|
||||
"""
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
self.func_dict[node.op](node)
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def placeholder_handler(self, node: Node) -> None:
|
||||
"""
|
||||
Handle the placeholder node.
|
||||
"""
|
||||
graph_info = GraphInfo()
|
||||
out = _normalize_tuple(getattr(node, '_meta_data', None))
|
||||
graph_info.fwd_out = list(out) if out[0] is not None else []
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def get_attr_handler(self, node: Node) -> None:
|
||||
"""
|
||||
Handle the get_attr node.
|
||||
"""
|
||||
graph_info = GraphInfo()
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def output_handler(self, node: Node) -> None:
|
||||
"""
|
||||
Handle the output node.
|
||||
"""
|
||||
graph_info = GraphInfo()
|
||||
output_tensors = []
|
||||
for par in node._input_nodes:
|
||||
if par.meta:
|
||||
output_tensors += par.meta["fwd_out"]
|
||||
graph_info.fwd_in = output_tensors
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def node_handler(self, node: Node) -> None:
|
||||
"""
|
||||
Handle other kinds of nodes
|
||||
"""
|
||||
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
|
||||
graph_info = GraphInfo()
|
||||
meta_info = node.best_metainfo
|
||||
meta_info: MetaInfo
|
||||
|
||||
# set data_ptr for input_tensor in MetaInfo class
|
||||
input_tensors: List[torch.Tensor] = meta_info.fwd_in
|
||||
buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
|
||||
output_tensors: List[torch.Tensor] = meta_info.fwd_out
|
||||
|
||||
if self._is_inplace(node):
|
||||
# an inplace operation does not create a new tensor, and it only has one parent node
|
||||
# TODO: Verify this observation
|
||||
# set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
|
||||
parent_node = list(node._input_nodes.keys())[0]
|
||||
parent_tensor = parent_node.meta.get("fwd_out")[0]
|
||||
parent_tensor: torch.Tensor
|
||||
for tensor in input_tensors:
|
||||
tensor.data_ptr = parent_tensor.data_ptr
|
||||
for tensor in buffer_tensors:
|
||||
tensor.data_ptr = parent_tensor.data_ptr
|
||||
for tensor in output_tensors:
|
||||
tensor.data_ptr = parent_tensor.data_ptr
|
||||
|
||||
else:
|
||||
for par in node._input_nodes:
|
||||
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
|
||||
for tensor in par.meta.get("fwd_out", []):
|
||||
tensor: torch.Tensor
|
||||
target_input_tensor = next(
|
||||
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
|
||||
if target_input_tensor is not None:
|
||||
target_input_tensor.data_ptr = tensor.data_ptr
|
||||
|
||||
# set data_ptr for tensor in input_tensor that is not set
|
||||
for tensor in input_tensors:
|
||||
if not tensor.data_ptr():
|
||||
self._set_data_ptr(tensor)
|
||||
|
||||
# set data_ptr for buffer_tensor
|
||||
for tensor in buffer_tensors:
|
||||
self._set_data_ptr(tensor)
|
||||
|
||||
# set data_ptr for output_tensor
|
||||
for tensor in output_tensors:
|
||||
self._set_data_ptr(tensor)
|
||||
|
||||
# attach them to graph_info
|
||||
graph_info.fwd_in = input_tensors
|
||||
graph_info.fwd_tmp = buffer_tensors
|
||||
graph_info.fwd_out = output_tensors
|
||||
|
||||
# fetch other memory information
|
||||
memory_cost = meta_info.memory_cost
|
||||
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
|
||||
graph_info.fwd_mem_out = memory_cost.fwd.activation
|
||||
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
|
||||
graph_info.bwd_mem_out = memory_cost.bwd.activation
|
||||
|
||||
# fetch flop information
|
||||
# here we use fwd_time and bwd_time to deal with the case that
|
||||
# communication cost is a float
|
||||
compute_cost = meta_info.compute_cost
|
||||
graph_info.fwd_time = compute_cost.fwd
|
||||
graph_info.bwd_time = compute_cost.bwd
|
||||
|
||||
node.meta = {**asdict(graph_info)}
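A hypothetical driver for the pass above. It assumes the auto-parallel solver has already attached best_metainfo and _meta_data to the relevant nodes; without those annotations the node_handler assert would fail, so this is only a sketch of the intended call order.

import torch
from torch.fx import symbolic_trace

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
gm = symbolic_trace(model)

# ... solver passes are assumed to have annotated `best_metainfo` / `_meta_data` here ...

MetaInfoProp(gm).run()
for node in gm.graph.nodes:
    print(node.name, node.meta.get('fwd_mem_out'))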
|
|
@ -0,0 +1,221 @@
|
|||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.comm_spec import CommSpec
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
|
||||
def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int):
|
||||
"""
|
||||
This method will be invoked during runtime to perform shape consistency, which makes sure the activation is
converted into the form expected by the user node.
|
||||
"""
|
||||
origin_sharding_spec = origin_dict[node_index]
|
||||
target_sharding_spec = input_dict[node_index][user_node_index]
|
||||
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
|
||||
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
|
||||
user_node_index: int):
|
||||
"""
|
||||
This method will be invoked during runtime to perform shape consistency, which makes sure activations of tuple or
list type are converted into the form expected by the user node.
|
||||
"""
|
||||
rst = []
|
||||
for index, (origin_sharding_spec,
|
||||
target_sharding_spec) in enumerate(zip(origin_dict[node_index],
|
||||
input_dict[node_index][user_node_index])):
|
||||
rst.append(
|
||||
shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
|
||||
target_sharding_spec))
|
||||
rst = type(node)(rst)
|
||||
return rst
|
||||
|
||||
|
||||
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
|
||||
"""
|
||||
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
|
||||
"""
|
||||
comm_action = comm_actions_dict[node_index][op_data_name]
|
||||
if isinstance(comm_action.comm_spec, CommSpec):
|
||||
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
|
||||
else:
|
||||
origin_sharding_spec = comm_action.comm_spec['src_spec']
|
||||
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
|
||||
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
|
||||
return rst
|
||||
|
||||
|
||||
def _preprocess_graph(nodes: List[Node]):
|
||||
"""
|
||||
This method is used to extract all the placeholders with sharding information,
and to map each node to its index in the origin graph.
|
||||
"""
|
||||
# mapping the node into the origin graph index
|
||||
node_to_index_dict = {}
|
||||
index = 0
|
||||
for node in nodes:
|
||||
if node.target == 'sharding_spec_convert_dict':
|
||||
input_dict_node = node
|
||||
continue
|
||||
if node.target == 'origin_node_sharding_spec_dict':
|
||||
origin_dict_node = node
|
||||
continue
|
||||
if node.target == 'comm_actions_dict':
|
||||
comm_actions_dict_node = node
|
||||
continue
|
||||
if not hasattr(node, 'best_strategy'):
|
||||
continue
|
||||
node_to_index_dict[node] = index
|
||||
index += 1
|
||||
|
||||
return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict
|
||||
|
||||
|
||||
def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
This pass is used to add the shape consistency node to the origin graph.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
|
||||
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
|
||||
|
||||
for node in nodes:
|
||||
if not hasattr(node, 'best_strategy') or node.op == 'output':
|
||||
continue
|
||||
|
||||
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
|
||||
if isinstance(node.sharding_spec, (list, tuple)):
|
||||
assert isinstance(
|
||||
node.target_sharding_specs,
|
||||
(list,
|
||||
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
|
||||
total_difference = 0
|
||||
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
|
||||
node.target_sharding_specs[user_node_index]):
|
||||
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
|
||||
if total_difference == 0:
|
||||
continue
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function',
|
||||
runtime_apply_for_iterable_object,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
|
||||
else:
|
||||
assert isinstance(node.sharding_spec,
|
||||
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
|
||||
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
|
||||
continue
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function',
|
||||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
# the origin node may be a positional argument or keyword argument of the user node
|
||||
if node in new_args:
|
||||
# substitute the origin node with shape_consistency_node
|
||||
origin_index_args = new_args.index(node)
|
||||
new_args[origin_index_args] = shape_consistency_node
|
||||
user_node.args = tuple(new_args)
|
||||
elif str(node) in new_kwargs:
|
||||
# substitute the origin node with shape_consistency_node
|
||||
new_kwargs[str(node)] = shape_consistency_node
|
||||
user_node.kwargs = new_kwargs
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def _comm_spec_apply(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
This pass is used to add the comm spec apply node to the origin graph.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
|
||||
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
|
||||
|
||||
for node in nodes:
|
||||
if not hasattr(node, 'best_strategy') or node.op == 'output':
|
||||
continue
|
||||
|
||||
comm_actions = node.best_strategy.communication_actions
|
||||
for op_data, comm_action in comm_actions.items():
|
||||
|
||||
if comm_action.comm_type == CommType.HOOK:
|
||||
continue
|
||||
if comm_action.comm_type == CommType.BEFORE:
|
||||
if op_data.type == OperationDataType.OUTPUT:
|
||||
comm_object = node
|
||||
elif comm_action.key_for_kwarg is not None:
|
||||
comm_object = node.kwargs[comm_action.key_for_kwarg]
|
||||
else:
|
||||
comm_object = node.args[comm_action.arg_index]
|
||||
with mod_graph.inserting_before(node):
|
||||
comm_spec_apply_node = mod_graph.create_node('call_function',
|
||||
runtime_comm_spec_apply,
|
||||
args=(comm_object, comm_actions_dict_node,
|
||||
node_to_index_dict[node], op_data.name))
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
if comm_action.key_for_kwarg is not None:
|
||||
# substitute the origin node with comm_spec_apply_node
|
||||
new_kwargs = dict(node.kwargs)
|
||||
new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
|
||||
node.kwargs = new_kwargs
|
||||
else:
|
||||
# substitute the origin node with comm_spec_apply_node
|
||||
new_args = list(node.args)
|
||||
new_args[comm_action.arg_index] = comm_spec_apply_node
|
||||
node.args = tuple(new_args)
|
||||
|
||||
elif comm_action.comm_type == CommType.AFTER:
|
||||
with mod_graph.inserting_after(node):
|
||||
comm_spec_apply_node = mod_graph.create_node('call_function',
|
||||
runtime_comm_spec_apply,
|
||||
args=(node, comm_actions_dict_node,
|
||||
node_to_index_dict[node], op_data.name))
|
||||
user_list = list(node.users.keys())
|
||||
for user in user_list:
|
||||
if user == comm_spec_apply_node:
|
||||
continue
|
||||
new_args = list(user.args)
|
||||
new_kwargs = dict(user.kwargs)
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
if node in new_args:
|
||||
# substitute the origin node with comm_spec_apply_node
|
||||
new_args[new_args.index(node)] = comm_spec_apply_node
|
||||
user.args = tuple(new_args)
|
||||
elif str(node) in new_kwargs:
|
||||
# substitute the origin node with comm_spec_apply_node
|
||||
new_kwargs[str(node)] = comm_spec_apply_node
|
||||
user.kwargs = new_kwargs
|
||||
return gm
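# Illustrative sketch only (not part of the original file): the two insertion modes
# handled above, with hypothetical node names.
#   CommType.BEFORE: the communication node is inserted in front of `node` and replaces
#                    the corresponding arg/kwarg of `node` itself, e.g.
#                    linear(x, w)  ->  linear(runtime_comm_spec_apply(x, ...), w)
#   CommType.AFTER:  the communication node is inserted right after `node`, and every
#                    other user of `node` is rewired to consume its output, e.g.
#                    y = linear(...); z = relu(y)
#                    ->  y = linear(...); y_comm = runtime_comm_spec_apply(y, ...); z = relu(y_comm)
#   CommType.HOOK:   skipped here; it is registered as a gradient hook in the
#                    runtime preparation pass instead.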
|
||||
|
||||
|
||||
def runtime_apply_pass(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
The method manages all the passes acting on the distributed training runtime.
|
||||
"""
|
||||
gm = _shape_consistency_apply(gm)
|
||||
gm = _comm_spec_apply(gm)
|
||||
|
||||
return gm
|
|
@ -0,0 +1,471 @@
|
|||
import operator
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.comm_spec import _all_reduce
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
|
||||
def size_processing(size: Union[int, torch.Size],
|
||||
dim_partition_dict: Dict[int, List[int]],
|
||||
device_mesh_info: Dict[int, int],
|
||||
target_dim: int = None,
|
||||
node_name: str = None):
|
||||
"""
|
||||
This method is invoked at runtime to convert the value produced by a size node according to the distributed sharding information.
|
||||
"""
|
||||
if target_dim is not None:
|
||||
assert isinstance(size, int)
|
||||
if target_dim in dim_partition_dict:
|
||||
total_shard_size = 1
|
||||
for shard_dim in dim_partition_dict[target_dim]:
|
||||
total_shard_size *= device_mesh_info[shard_dim]
|
||||
size = size * total_shard_size
|
||||
|
||||
else:
|
||||
size = list(size)
|
||||
for dim, dim_size in enumerate(size):
|
||||
if dim in dim_partition_dict:
|
||||
total_shard_size = 1
|
||||
for shard_dim in dim_partition_dict[dim]:
|
||||
total_shard_size *= device_mesh_info[shard_dim]
|
||||
size[dim] = dim_size * total_shard_size
|
||||
size = torch.Size(size)
|
||||
|
||||
return size
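# Illustrative sketch only (not part of the original file): a tiny, self-contained
# demonstration of the conversion performed by `size_processing`. The helper name and
# the concrete numbers are assumptions made for the example.
def _example_size_processing_usage():
    device_mesh_info = {0: 2, 1: 2}    # a 2 x 2 logical device mesh
    dim_partition_dict = {0: [0]}    # tensor dim 0 is sharded along mesh axis 0
    local_size = torch.Size([4, 16])    # the shape seen on each device
    # tensor.size() style: every sharded dim is scaled back to the global size
    assert size_processing(local_size, dim_partition_dict, device_mesh_info) == torch.Size([8, 16])
    # tensor.size(dim) style: a single int is scaled the same way
    assert size_processing(4, dim_partition_dict, device_mesh_info, target_dim=0) == 8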
|
||||
|
||||
|
||||
def _solution_annotatation(gm: torch.fx.GraphModule,
|
||||
solution: List[int],
|
||||
strategies_constructor: StrategiesConstructor = None):
|
||||
"""
|
||||
This method is used to attach the solution strategy to the nodes and add the information
|
||||
required at runtime into the graph as placeholder nodes.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
# TODO: In future PR, strategies_constructor should be a required argument,
|
||||
# instead of optional argument. This is because we don't need to consider nodes with
|
||||
# no strategy in runtime preparation pass.
|
||||
if strategies_constructor is not None:
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
no_strategy_nodes = strategies_constructor.no_strategy_nodes
|
||||
else:
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
no_strategy_nodes = []
|
||||
|
||||
# the dict to get origin sharding spec of node
|
||||
origin_node_sharding_spec_dict = {}
|
||||
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
|
||||
strategies_vector = node.strategies_vector
|
||||
# stick the solution strategy to the corresponding node
|
||||
setattr(node, 'best_strategy', strategies_vector[strategy_index])
|
||||
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
|
||||
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
||||
str(node))
|
||||
|
||||
# attach the corresponding metainfo if node has the attribute `metainfo_vector`
|
||||
if hasattr(node, 'metainfo_vector'):
|
||||
setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index])
|
||||
|
||||
# the dict to get input sharding specs of user node
|
||||
sharding_spec_convert_dict = {}
|
||||
# the dict to record comm actions of nodes
|
||||
comm_actions_dict = {}
|
||||
for index, node in enumerate(nodes):
|
||||
target_sharding_specs = []
|
||||
for user_node in node.strategies_vector.successor_nodes:
|
||||
if user_node in no_strategy_nodes:
|
||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
|
||||
else:
|
||||
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
|
||||
target_sharding_specs.append(target_sharding_spec)
|
||||
sharding_spec_convert_dict[index] = target_sharding_specs
|
||||
setattr(node, 'target_sharding_specs', target_sharding_specs)
|
||||
# the get_attr node strategy is a kind of pending strategy, which means we will change it
|
||||
# to the same strategy as its user node.
|
||||
if node.op == 'get_attr':
|
||||
assert len(target_sharding_specs) == 1, 'sharing a weight is not supported in the current version.'
|
||||
target_node = node.strategies_vector.successor_nodes[0]
|
||||
node_name = str(node)
|
||||
if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
|
||||
node_name = str(target_node)
|
||||
target_node = target_node.strategies_vector.successor_nodes[0]
|
||||
user_strategy = target_node.best_strategy
|
||||
op_data_in_user = user_strategy.get_op_data_by_name(node_name)
|
||||
origin_pending_strategy = node.best_strategy
|
||||
origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
|
||||
|
||||
new_communication_actions = {}
|
||||
if op_data_in_user in user_strategy.communication_actions:
|
||||
new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
|
||||
new_communication_action.arg_index = 0
|
||||
new_communication_actions[origin_op_data] = new_communication_action
|
||||
node.best_strategy.communication_actions = new_communication_actions
|
||||
|
||||
comm_action_dict = {}
|
||||
for op_data, comm_action in node.best_strategy.communication_actions.items():
|
||||
comm_action_dict[op_data.name] = comm_action
|
||||
comm_actions_dict[index] = comm_action_dict
|
||||
|
||||
# add above dicts into graph
|
||||
for node in nodes:
|
||||
if node.op != 'placeholder':
|
||||
with mod_graph.inserting_before(node):
|
||||
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
|
||||
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
|
||||
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
|
||||
break
|
||||
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
||||
|
||||
|
||||
def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||
"""
|
||||
In the auto parallel system, tensors may be sharded across different devices, so the size of a tensor
|
||||
needs to be converted back to the size of the original (unsharded) tensor before it is consumed by users such as torch.view,
|
||||
torch.reshape, etc. These nodes carry enough information, i.e. the input sharding_spec and
|
||||
output sharding_spec, to decide how to convert the size value.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
node_pairs = {}
|
||||
|
||||
for node in nodes:
|
||||
|
||||
if node.op == 'call_method' and node.target == 'size':
|
||||
# extract useful information from size node
|
||||
# dim_partition_dict indicates along which dimensions the size
|
||||
# value should be enlarged.
|
||||
sharding_spec = node.args[0].sharding_spec
|
||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||
|
||||
# there are two usages of torch.Tensor.size:
|
||||
# tensor.size()
|
||||
# tensor.size(dim)
|
||||
# if a target_dim is assigned, then the output will be
|
||||
# of type int instead of torch.Size
|
||||
target_dim = None
|
||||
if len(node.args) > 1:
|
||||
target_dim = node.args[1]
|
||||
if target_dim < 0:
|
||||
target_dim += node.args[0]._meta_data.dim()
|
||||
|
||||
# DeviceMesh information instructs the scaling of the size value
|
||||
device_mesh_info = {}
|
||||
for dim, dim_size in enumerate(device_mesh.mesh_shape):
|
||||
device_mesh_info[dim] = dim_size
|
||||
|
||||
with mod_graph.inserting_after(node):
|
||||
size_processing_node = mod_graph.create_node('call_function',
|
||||
size_processing,
|
||||
args=(node, dim_partition_dict, device_mesh_info,
|
||||
target_dim, node.name))
|
||||
# store the original node and processing node pair in the node_pairs dictionary
|
||||
# it will be used to replace the original node with the processing node inside slice objects
|
||||
node_pairs[node] = size_processing_node
|
||||
size_processing_node._meta_data = node._meta_data
|
||||
|
||||
user_list = list(node.users.keys())
|
||||
for user in user_list:
|
||||
if user == size_processing_node:
|
||||
continue
|
||||
new_args = list(user.args)
|
||||
new_kwargs = dict(user.kwargs)
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
if node in new_args:
|
||||
# substitute the origin node with size_processing_node
|
||||
new_args[new_args.index(node)] = size_processing_node
|
||||
user.args = tuple(new_args)
|
||||
elif str(node) in new_kwargs:
|
||||
# substitute the origin node with size_processing_node
|
||||
new_kwargs[str(node)] = size_processing_node
|
||||
user.kwargs = new_kwargs
|
||||
|
||||
if node.op == 'call_function' and node.target == operator.getitem:
|
||||
|
||||
getitem_index = node.args[1]
|
||||
# slice object is quite special in torch.fx graph,
|
||||
# On the one hand, we treat a slice object the same as an int,
|
||||
# so we do not create a node for it. On the other hand,
|
||||
# a slice object can take an fx.Node as its argument, and that user
|
||||
# relationship cannot be tracked in fx graph.
|
||||
# Therefore, we record the node_pairs in this pass and use them
|
||||
# to replace the original node argument inside the slice object if
|
||||
# it has been processed in the pass above.
|
||||
|
||||
# There are three main usages of operator.getitem:
|
||||
# getitem(input, int)
|
||||
# getitem(input, slice)
|
||||
# getitem(input, Tuple[slice])
|
||||
# In this pass, we need to process the last two cases because
|
||||
# node arguments may potentially appear in these cases.
|
||||
if isinstance(getitem_index, slice):
|
||||
new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
|
||||
if getitem_index.start in node_pairs:
|
||||
new_start = node_pairs[getitem_index.start]
|
||||
elif getitem_index.stop in node_pairs:
|
||||
new_stop = node_pairs[getitem_index.stop]
|
||||
elif getitem_index.step in node_pairs:
|
||||
new_step = node_pairs[getitem_index.step]
|
||||
new_slice_item = slice(new_start, new_stop, new_step)
|
||||
new_args = (node.args[0], new_slice_item)
|
||||
node.args = new_args
|
||||
|
||||
elif isinstance(getitem_index, (tuple, list)):
|
||||
if not isinstance(getitem_index[0], slice):
|
||||
continue
|
||||
new_slice_items = []
|
||||
|
||||
for slice_item in getitem_index:
|
||||
if slice_item is None:
|
||||
new_slice_items.append(None)
|
||||
continue
|
||||
|
||||
new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
|
||||
|
||||
if slice_item.start in node_pairs:
|
||||
new_start = node_pairs[slice_item.start]
|
||||
elif slice_item.stop in node_pairs:
|
||||
new_stop = node_pairs[slice_item.stop]
|
||||
elif slice_item.step in node_pairs:
|
||||
new_step = node_pairs[slice_item.step]
|
||||
new_slice_item = slice(new_start, new_stop, new_step)
|
||||
new_slice_items.append(new_slice_item)
|
||||
|
||||
new_args = (node.args[0], tuple(new_slice_items))
|
||||
node.args = new_args
|
||||
|
||||
return gm
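# Illustrative sketch only (not part of the original file): why `node_pairs` is needed
# for getitem nodes, with hypothetical traced code.
#   length = x.size(0)        # `length` becomes a size node and is rescaled above
#   y = x[0:length]           # traced as getitem(x, slice(0, length, None))
# The slice object keeps a reference to the original size node, and torch.fx does not
# track that reference as a user relationship. The getitem branch above therefore
# rebuilds the slice as slice(0, size_processing_node, None) using the recorded
# node_pairs mapping.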
|
||||
|
||||
|
||||
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||
"""
|
||||
This pass processes node args so that they adapt to the distributed tensor layout.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
|
||||
for node in nodes:
|
||||
# skip the placeholder node added in _solution_annotation pass
|
||||
if not hasattr(node, 'sharding_spec'):
|
||||
continue
|
||||
|
||||
def _process_sharding_spec(sharding_spec):
|
||||
if isinstance(sharding_spec, ShardingSpec):
|
||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||
device_mesh = sharding_spec.device_mesh
|
||||
return dim_partition_dict, device_mesh
|
||||
if sharding_spec is None:
|
||||
return None, None
|
||||
assert isinstance(sharding_spec,
|
||||
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
|
||||
|
||||
device_mesh = sharding_spec[0].device_mesh
|
||||
dim_partition_dict = []
|
||||
for element in sharding_spec:
|
||||
dim_partition_dict.append(_process_sharding_spec(element))
|
||||
return dim_partition_dict, sharding_spec
|
||||
|
||||
output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
|
||||
new_args = []
|
||||
|
||||
if node.op == 'call_method':
|
||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||
# process the node with (input, *shape) style args
|
||||
if method in (torch.Tensor.view, torch.Tensor.reshape):
|
||||
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node):
|
||||
if isinstance(arg._meta_data, (int, tuple, list)):
|
||||
new_args.append(arg._meta_data)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
else:
|
||||
assert isinstance(
|
||||
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
|
||||
new_args.append(arg)
|
||||
|
||||
for dim, shard_dims in output_dim_partition_dict.items():
|
||||
total_shard_size = 1
|
||||
for shard_dim in shard_dims:
|
||||
total_shard_size *= device_mesh.shape[shard_dim]
|
||||
# There are two ways to use torch.view:
|
||||
# 1. torch.view(input, *shape)
|
||||
# 2. torch.view(input, shape)
|
||||
if isinstance(new_args[1], int):
|
||||
# we will skip the dim with -1 value
|
||||
if new_args[dim + 1] == -1:
|
||||
continue
|
||||
else:
|
||||
new_args[dim + 1] //= total_shard_size
|
||||
else:
|
||||
new_args[1] = list(new_args[1])
|
||||
# we will skip the dim with -1 value
|
||||
if new_args[1][dim] == -1:
|
||||
continue
|
||||
else:
|
||||
new_args[1][dim] //= total_shard_size
|
||||
node.args = tuple(new_args)
|
||||
|
||||
elif node.op == 'call_function':
|
||||
target = node.target
|
||||
# process the node with (input, torch.Size) style args
|
||||
if target in (torch.reshape,):
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node):
|
||||
if isinstance(arg._meta_data, (tuple, list)):
|
||||
new_args.append(list(arg._meta_data))
|
||||
else:
|
||||
new_args.append(arg)
|
||||
else:
|
||||
assert isinstance(
|
||||
arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
|
||||
new_args.append(list(arg))
|
||||
|
||||
for dim, shard_dims in output_dim_partition_dict.items():
|
||||
# we will skip the dim with -1 value
|
||||
if new_args[1][dim] == -1:
|
||||
continue
|
||||
total_shard_size = 1
|
||||
for shard_dim in shard_dims:
|
||||
total_shard_size *= device_mesh.shape[shard_dim]
|
||||
new_args[1][dim] //= total_shard_size
|
||||
node.args = tuple(new_args)
|
||||
|
||||
return gm
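# Illustrative sketch only (not part of the original file): the effect of this pass on
# a view node, with assumed shapes. Suppose a traced call `x.view(8, 16)` whose output
# sharding spec shards dim 0 across 2 devices (output_dim_partition_dict = {0: [0]},
# device_mesh.shape[0] == 2). The pass divides the corresponding shape argument so
# that each device builds its local shard directly:
#   before: node.args == (x, 8, 16)
#   after : node.args == (x, 4, 16)    # 8 // 2, dim 1 untouched
# A dimension given as -1 is skipped and keeps being inferred by view/reshape.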
|
||||
|
||||
|
||||
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||
"""
|
||||
Apply the sharding action to the module parameters and buffers following the
|
||||
instructions of the solver solution.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
# This stream is created to overlap communication and computation.
|
||||
reduction_stream = torch.cuda.Stream()
|
||||
for node in nodes:
|
||||
if node.op == 'call_module':
|
||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||
# TODO: we need to do more actions to take care of the shared parameters.
|
||||
if hasattr(target_module, 'processed') and target_module.processed:
|
||||
continue
|
||||
setattr(target_module, 'processed', True)
|
||||
for name, param in target_module.named_parameters():
|
||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||
# apply the sharding spec of parameters
|
||||
if target_sharding_spec.dim_partition_dict != {}:
|
||||
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
|
||||
setattr(param, 'sharding_spec', origin_sharding_spec)
|
||||
# TODO: build a ColoParameter class to manage the distributed parameters
|
||||
# we could use .data here, because all the operations just happen before the real training
|
||||
# loop, so we don't need to track these operations in the autograd graph.
|
||||
param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
|
||||
param.data, param.sharding_spec, target_sharding_spec).detach().clone()
|
||||
|
||||
setattr(target_module, name, param)
|
||||
comm_actions = node.best_strategy.communication_actions
|
||||
for operation_data, comm_action in comm_actions.items():
|
||||
comm_spec_to_use = comm_action.comm_spec
|
||||
# register hook to the parameters
|
||||
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
|
||||
|
||||
def wrapper(param, comm_spec):
|
||||
|
||||
def hook_fn(grad):
|
||||
_all_reduce(grad, comm_spec, async_op=False)
|
||||
|
||||
param.register_hook(hook_fn)
|
||||
|
||||
wrapper(param, comm_spec_to_use)
|
||||
|
||||
sharded_buffer_dict = {}
|
||||
# apply the sharding spec of buffers
|
||||
for name, buffer in target_module.named_buffers():
|
||||
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
|
||||
setattr(buffer, 'sharding_spec', origin_sharding_spec)
|
||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
|
||||
sharded_buffer_dict[name] = buffer_sharded
|
||||
|
||||
for name, buffer_sharded in sharded_buffer_dict.items():
|
||||
setattr(target_module, name, buffer_sharded.detach().clone())
|
||||
|
||||
if node.op == 'get_attr':
|
||||
root = node.graph.owning_module
|
||||
atoms = node.target.split(".")
|
||||
attr_len = len(atoms)
|
||||
if attr_len == 1:
|
||||
target_module = root
|
||||
target = getattr(root, atoms[0])
|
||||
else:
|
||||
target_module = root
|
||||
for atom in atoms[:-1]:
|
||||
target_module = getattr(target_module, atom)
|
||||
target = getattr(target_module, atoms[-1])
|
||||
|
||||
target_sharding_spec = node.sharding_spec
|
||||
if target_sharding_spec.dim_partition_dict != {}:
|
||||
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
|
||||
setattr(target, 'sharding_spec', origin_sharding_spec)
|
||||
# TODO: build a ColoParameter class to manage the distributed parameters
|
||||
# we could use .data here, because all the operations just happen before the real training
|
||||
# loop, so we don't need to track these operations in the autograd graph.
|
||||
target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
|
||||
target.data, target.sharding_spec, target_sharding_spec).detach().clone()
|
||||
|
||||
assert hasattr(target_module, atoms[-1])
|
||||
setattr(target_module, atoms[-1], target)
|
||||
|
||||
comm_actions = node.best_strategy.communication_actions
|
||||
for operation_data, comm_action in comm_actions.items():
|
||||
comm_spec_to_use = comm_action.comm_spec
|
||||
# register hook to the parameters
|
||||
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
|
||||
|
||||
def wrapper(param, comm_spec):
|
||||
|
||||
def hook_fn(grad):
|
||||
_all_reduce(grad, comm_spec, async_op=False)
|
||||
|
||||
param.register_hook(hook_fn)
|
||||
|
||||
wrapper(target, comm_spec_to_use)
|
||||
return gm
|
||||
|
||||
|
||||
def implicit_comm_action_apply(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
Replace the original kernel with a kernel that performs the communication implicitly.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def runtime_preparation_pass(gm: torch.fx.GraphModule,
|
||||
solution: List[int],
|
||||
device_mesh: DeviceMesh,
|
||||
strategies_constructor: StrategiesConstructor = None):
|
||||
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
|
||||
gm, solution, strategies_constructor)
|
||||
gm = _size_value_converting(gm, device_mesh)
|
||||
gm = _node_args_converting(gm, device_mesh)
|
||||
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
|
||||
# gm = implicit_comm_action_apply(gm)
|
||||
gm = _module_params_sharding(gm, device_mesh)
|
||||
|
||||
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
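# Illustrative sketch only (not part of the original file): how the three returned
# dictionaries are fed back into the transformed module at runtime. They are keyed by
# the node index assigned during solution annotation and are consumed through the
# placeholder nodes created in `_solution_annotatation`, e.g.
#   gm, convert_dict, origin_dict, comm_dict = runtime_preparation_pass(
#       gm, solution, device_mesh, strategies_constructor)
#   gm = runtime_apply_pass(gm)    # defined in runtime_apply_pass.py
#   gm.recompile()
#   output = gm(data,
#               sharding_spec_convert_dict=convert_dict,
#               origin_node_sharding_spec_dict=origin_dict,
#               comm_actions_dict=comm_dict)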
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import operator
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
|
||||
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
|
||||
|
@ -25,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
|
|||
# TODO: contiguous maybe need some extra processes.
|
||||
torch.Tensor.contiguous
|
||||
]
|
||||
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
|
||||
RESHAPE_FUNC_OP = [
|
||||
torch.flatten,
|
||||
torch.reshape,
|
||||
torch.transpose,
|
||||
torch.split,
|
||||
torch.permute,
|
||||
operator.getitem,
|
||||
]
|
||||
RESHAPE_METHOD_OP = [
|
||||
torch.Tensor.view,
|
||||
torch.Tensor.unsqueeze,
|
||||
|
@ -35,7 +43,7 @@ RESHAPE_METHOD_OP = [
|
|||
]
|
||||
BCAST_FUNC_OP = [
|
||||
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
|
||||
operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
|
||||
operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
|
||||
]
|
||||
CONV_MODULE_OP = [
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from .options import SolverOptions
|
||||
from .strategies_constructor import StrategiesConstructor
|
||||
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .cost_graph import CostGraph
|
||||
from .graph_analysis import GraphAnalyser
|
||||
from .options import SolverOptions
|
||||
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .solver import Solver
|
||||
from .graph_analysis import GraphAnalyser
|
||||
from .strategies_constructor import StrategiesConstructor
|
||||
|
|
|
@ -5,10 +5,11 @@ from functools import reduce
|
|||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .constants import INFINITY_COST
|
||||
|
||||
|
@ -17,7 +18,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
|
|||
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||
"""
|
||||
Generate the sharding spec of the tensor based on the given dim_partition_dict.
|
||||
|
||||
|
||||
|
||||
Args:
|
||||
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
|
||||
|
@ -58,7 +59,7 @@ def generate_resharding_costs(nodes: List[Node],
|
|||
nodes (List[Node]): a list of nodes
|
||||
sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
|
||||
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
|
||||
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
|
||||
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
|
||||
'''
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
resharding_costs = {}
|
||||
|
|
|
@ -3,9 +3,9 @@ import warnings
|
|||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
|
@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler):
|
|||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
|
||||
into sharding_size_forward partitions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation partitions.
|
||||
sharding_size_weight(int): The backward weight will be divided
|
||||
into sharding_size_weight partitions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
|
@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler):
|
|||
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
|
||||
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
|
||||
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
|
||||
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
|
||||
conv_handler.register_strategy_into_strategies_vector()
|
||||
for strategy in conv_handler.strategies_vector:
|
||||
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
|
||||
|
||||
|
||||
Output:
|
||||
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
|
||||
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
|
||||
|
|
|
@ -6,9 +6,9 @@ from typing import List
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
|
||||
from .operator_handler import OperatorHandler
|
||||
|
@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator):
|
|||
|
||||
class MatMulStrategyGenerator(StrategyGenerator):
|
||||
"""
|
||||
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
|
||||
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
|
||||
a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
|
||||
|
||||
A matmul can be formulated as [n, p] x [p, q] = [n, q]
|
||||
|
||||
Args:
|
||||
is_linear (bool): whether this generator is used for nn.Linear and F.linear.
|
||||
is_linear (bool): whether this generator is used for nn.Linear and F.linear.
|
||||
This will incur extra transformation of the dim partitioning as the weight is transposed.
|
||||
"""
|
||||
|
||||
|
@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
|||
"""
|
||||
Generate sharding strategies for the batched matrix multiplication.
|
||||
|
||||
A batched matrix multiplication can be viewed as
|
||||
A batched matrix multiplication can be viewed as
|
||||
[b, i, k] x [b, k, j] -> [b, i, j]
|
||||
"""
|
||||
|
||||
|
@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -451,7 +451,7 @@ class DotHandler(OperatorHandler):
|
|||
|
||||
# create and register strategy
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -491,7 +491,7 @@ class DotHandler(OperatorHandler):
|
|||
communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
|
||||
communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -529,7 +529,7 @@ class DotHandler(OperatorHandler):
|
|||
communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -564,7 +564,7 @@ class DotHandler(OperatorHandler):
|
|||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -600,7 +600,7 @@ class DotHandler(OperatorHandler):
|
|||
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
|
||||
communication_cost = communication_cost_activation_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -636,7 +636,7 @@ class DotHandler(OperatorHandler):
|
|||
communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
|
||||
communication_cost = communication_cost_weight_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -673,7 +673,7 @@ class DotHandler(OperatorHandler):
|
|||
activation_memory_cost, 0)
|
||||
communication_cost = communication_cost_forward_activation
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -709,7 +709,7 @@ class DotHandler(OperatorHandler):
|
|||
input_grad_memory_cost, 0)
|
||||
communication_cost = communication_cost_activation_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
|
|
|
@ -2,10 +2,14 @@ import operator
|
|||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
generate_sharding_size, ignore_sharding_exception)
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
generate_sharding_size,
|
||||
ignore_sharding_exception,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
|
@ -63,19 +67,19 @@ class LayerNormHandler(OperatorHandler):
|
|||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
|
||||
into sharding_size_forward partitions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation partitions.
|
||||
sharding_size_weight(int): The backward weight will be divided
|
||||
into sharding_size_weight partitions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
|
@ -216,7 +220,7 @@ class LayerNormHandler(OperatorHandler):
|
|||
norm_handler.register_strategy()
|
||||
for strategy in norm_handler.strategies_vector:
|
||||
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
|
||||
|
||||
|
||||
Output:
|
||||
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
|
||||
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
from webbrowser import Opera
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from abc import ABC, abstractmethod
|
||||
from torch.fx.node import Node
|
||||
from typing import Dict, List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from .._utils import generate_resharding_costs, generate_sharding_spec
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
||||
|
||||
from .._utils import generate_resharding_costs, generate_sharding_spec
|
||||
from ..sharding_strategy import StrategiesVector
|
||||
|
||||
__all__ = ['OperatorHandler']
|
||||
|
@ -60,7 +62,7 @@ class OperatorHandler(ABC):
|
|||
@abstractmethod
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
"""
|
||||
Register
|
||||
Register
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
@ -4,9 +4,9 @@ import warnings
|
|||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
import builtins
|
||||
import math
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from ._utils import generate_resharding_costs, generate_sharding_spec
|
||||
from .constants import *
|
||||
from .op_handler import *
|
||||
from .options import SolverOptions
|
||||
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .op_handler import *
|
||||
from .constants import *
|
||||
from copy import deepcopy
|
||||
import math
|
||||
import torch
|
||||
import operator
|
||||
from typing import Dict, List
|
||||
from ._utils import generate_sharding_spec, generate_resharding_costs
|
||||
import builtins
|
||||
|
||||
__all__ = ['StrategiesConstructor']
|
||||
|
||||
|
|
|
@ -0,0 +1,275 @@
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.graph import Graph
|
||||
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
||||
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
|
||||
from colossalai.auto_parallel.tensor_shard.solver import (
|
||||
CostGraph,
|
||||
GraphAnalyser,
|
||||
Solver,
|
||||
SolverOptions,
|
||||
StrategiesConstructor,
|
||||
)
|
||||
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
class ModuleWrapper(nn.Module):
|
||||
'''
|
||||
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
|
||||
into the forward function.
|
||||
'''
|
||||
|
||||
def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
|
||||
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
|
||||
'''
|
||||
Args:
|
||||
module: the original module
|
||||
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required by its user nodes.
|
||||
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
|
||||
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
|
||||
'''
|
||||
super(ModuleWrapper, self).__init__()
|
||||
self.module = module
|
||||
self.sharding_spec_dict = sharding_spec_dict
|
||||
self.origin_spec_dict = origin_spec_dict
|
||||
self.comm_actions_dict = comm_actions_dict
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args,
|
||||
sharding_spec_convert_dict=self.sharding_spec_dict,
|
||||
origin_node_sharding_spec_dict=self.origin_spec_dict,
|
||||
comm_actions_dict=self.comm_actions_dict,
|
||||
**kwargs)
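# Illustrative sketch only (not part of the original file): the wrapper is normally
# built from the outputs of the two runtime passes, as done in `initialize_model`
# below, e.g.
#   gm, (sharding_spec_dict, origin_spec_dict, comm_actions_dict) = \
#       transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
#   model = ModuleWrapper(gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
#   output = model(data)    # the extra dicts are injected into forward automatically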
|
||||
|
||||
|
||||
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
|
||||
'''
|
||||
This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func.
|
||||
'''
|
||||
# TODO: implement this function
|
||||
pass
|
||||
|
||||
|
||||
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
|
||||
'''
|
||||
This method is used to search for the best logical mesh shape for the given world size
|
||||
based on the alpha_beta_dict.
|
||||
|
||||
For example:
|
||||
if the world_size is 8, the possible logical mesh shapes are (1, 8), (2, 4), (4, 2) and (8, 1).
|
||||
'''
|
||||
# TODO: implement this function
|
||||
return (world_size, 1)
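# Illustrative sketch only (not part of the original file): one way to enumerate the
# candidate 2D logical mesh shapes described in the docstring above. The helper name
# is an assumption; scoring the candidates with alpha_beta_dict is still a TODO.
def _enumerate_logical_mesh_shapes(world_size: int) -> List[Tuple[int, int]]:
    shapes = []
    for first_dim in range(1, world_size + 1):
        if world_size % first_dim == 0:
            shapes.append((first_dim, world_size // first_dim))
    # e.g. world_size == 8 -> [(1, 8), (2, 4), (4, 2), (8, 1)]
    return shapes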
|
||||
|
||||
|
||||
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
|
||||
'''
|
||||
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
|
||||
from the alpha_beta_dict. These two values will be used to estimate the communication cost.
|
||||
'''
|
||||
# TODO: implement this function
|
||||
pass
|
||||
|
||||
|
||||
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
|
||||
'''
|
||||
This method is used to build the strategy_constructor for the given graph.
|
||||
After this method, each node in the graph will have a strategies_vector which
|
||||
is constructed by the related node handler.
|
||||
'''
|
||||
solver_options = SolverOptions()
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
return strategies_constructor
|
||||
|
||||
|
||||
def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
|
||||
'''
|
||||
This method is used to solve the best solution for the given graph.
|
||||
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
|
||||
'''
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
cost_graph = CostGraph(strategy_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
solution = list(ret[0])
|
||||
|
||||
return solution
|
||||
|
||||
|
||||
def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
|
||||
strategies_constructor: StrategiesConstructor):
|
||||
'''
|
||||
This method is used to transform the original graph to the sharded graph.
|
||||
The model parameters will be sharded according to the solution and the grad hooks
|
||||
will be added to the sharded graph using the runtime_preparation_pass.
|
||||
Communication nodes will be inserted into the graph by the runtime_apply_pass.
|
||||
'''
|
||||
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
||||
gm, solution, device_mesh, strategies_constructor)
|
||||
gm = runtime_apply_pass(gm)
|
||||
gm.recompile()
|
||||
sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
|
||||
|
||||
return gm, sharding_spec_dicts
|
||||
|
||||
|
||||
def initialize_device_mesh(world_size: int = -1,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None):
|
||||
'''
|
||||
This method is used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
world_size(optional): the size of device mesh. If the world_size is -1,
|
||||
the world size will be taken from the default torch.distributed process group.
|
||||
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
|
||||
for each device. If the alpha_beta_dict is None, the alpha_beta_dict will be
|
||||
generated by profile_alpha_beta function.
|
||||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
'''
|
||||
# if world_size is not set, use the world size from torch.distributed
|
||||
if world_size == -1:
|
||||
world_size = dist.get_world_size()
|
||||
device1d = [i for i in range(world_size)]
|
||||
|
||||
if alpha_beta_dict is None:
|
||||
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
|
||||
alpha_beta_dict = profile_alpha_beta(device1d)
|
||||
|
||||
if logical_mesh_shape is None:
|
||||
# search for the best logical mesh shape
|
||||
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
|
||||
physical_mesh = torch.tensor(device1d)
|
||||
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
|
||||
mesh_shape=logical_mesh_shape,
|
||||
mesh_alpha=mesh_alpha,
|
||||
mesh_beta=mesh_beta,
|
||||
init_process_group=True)
|
||||
return device_mesh
|
||||
|
||||
|
||||
def initialize_model(model: nn.Module,
|
||||
meta_args: Dict[str, torch.Tensor],
|
||||
device_mesh: DeviceMesh,
|
||||
memory_budget: float = -1.0,
|
||||
save_solver_solution: bool = False,
|
||||
load_solver_solution: bool = False,
|
||||
solution_path: str = None,
|
||||
return_solution: bool = False):
|
||||
'''
|
||||
This method is used to initialize the sharded model, which can then be used as a normal PyTorch model.
|
||||
|
||||
Args:
|
||||
model: the model to be sharded.
|
||||
meta_args: the meta_args is used to specify the input shapes of the model.
|
||||
device_mesh: the device mesh to execute the model.
|
||||
memory_budget(optional): the maximum CUDA memory that can be used. If the memory budget is -1.0,
|
||||
the memory budget is treated as infinite.
|
||||
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
|
||||
to the solution_path.
|
||||
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
|
||||
from the solution_path.
|
||||
solution_path(optional): the path to save or load the solution.
|
||||
return_solution(optional): if the return_solution is True, the solution will be returned. The returned
|
||||
solution can be used to debug or analyze the sharding result. Therefore, we will not just
|
||||
return a series of integers, but return the best strategies.
|
||||
'''
|
||||
tracer = ColoTracer()
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
strategies_constructor = build_strategy_constructor(graph, device_mesh)
|
||||
if load_solver_solution:
|
||||
solution = torch.load(solution_path)
|
||||
else:
|
||||
solution = solve_solution(gm, strategies_constructor, memory_budget)
|
||||
if save_solver_solution:
|
||||
torch.save(solution, solution_path)
|
||||
|
||||
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
|
||||
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
|
||||
|
||||
if return_solution:
|
||||
solution_to_return = []
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
for index, node in enumerate(nodes):
|
||||
solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
|
||||
return model_to_return, solution_to_return
|
||||
else:
|
||||
return model_to_return
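# Illustrative sketch only (not part of the original file): a minimal call to
# `initialize_model`, assuming `MyModel` and an already-constructed `device_mesh`
# (some of the mesh helpers above are still TODO stubs). Meta tensors only carry
# shape information, no real data.
#   meta_args = {'x': torch.rand(16, 3, 224, 224).to('meta')}
#   model = initialize_model(MyModel(), meta_args, device_mesh, memory_budget=-1.0)
#   output = model(torch.rand(16, 3, 224, 224).cuda())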
|
||||
|
||||
|
||||
def autoparallelize(model: nn.Module,
|
||||
meta_args: Dict[str, torch.Tensor] = None,
|
||||
data_loader: torch.utils.data.DataLoader = None,
|
||||
data_process_func: callable = None,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None,
|
||||
save_solver_solution: bool = False,
|
||||
load_solver_solution: bool = False,
|
||||
solver_solution_path: str = None,
|
||||
return_solution: bool = False,
|
||||
memory_budget: float = -1.0):
|
||||
'''
|
||||
This method is used to initialize the device mesh, extract the meta_args, and
|
||||
use them to create a sharded model.
|
||||
|
||||
Args:
|
||||
model: the model to be sharded.
|
||||
meta_args(optional): the meta_args is used to specify the input shapes of the model.
|
||||
If the meta_args is None, the meta_args will be extracted from the data_loader.
|
||||
data_loader(optional): the data_loader to be used in normal training loop.
|
||||
data_process_func(optional): the data_process_func is used to process the data from the data_loader.
|
||||
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
|
||||
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
|
||||
generated by profile_alpha_beta function.
|
||||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
|
||||
to the solution_path.
|
||||
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
|
||||
from the solution_path.
|
||||
solver_solution_path(optional): the path to save or load the solution.
|
||||
return_solution(optional): if the return_solution is True, the solution will be returned.
|
||||
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
|
||||
the memory budget will be infinity.
|
||||
'''
|
||||
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
|
||||
if meta_args is None:
|
||||
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
|
||||
|
||||
rst_to_unpack = initialize_model(model,
|
||||
meta_args,
|
||||
device_mesh,
|
||||
save_solver_solution=save_solver_solution,
|
||||
load_solver_solution=load_solver_solution,
|
||||
solver_solution_path=solver_solution_path,
|
||||
return_solution=return_solution,
|
||||
memory_budget=memory_budget)
|
||||
|
||||
if return_solution:
|
||||
model, solution = rst_to_unpack
|
||||
return model, solution
|
||||
else:
|
||||
model = rst_to_unpack
|
||||
return model
|
|
@ -1,19 +1,31 @@
|
|||
from .addmm_handler import ADDMMFunctionHandler
|
||||
from .batch_norm_handler import BatchNormModuleHandler
|
||||
from .binary_elementwise_handler import BinaryElementwiseHandler
|
||||
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
|
||||
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
|
||||
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
|
||||
from .experimental import PermuteHandler, ViewHandler
|
||||
from .getattr_handler import GetattrHandler
|
||||
from .getitem_handler import GetItemHandler
|
||||
from .layer_norm_handler import LayerNormModuleHandler
|
||||
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
|
||||
from .matmul_handler import MatMulHandler
|
||||
from .normal_pooling_handler import NormPoolingHandler
|
||||
from .output_handler import OuputHandler
|
||||
from .placeholder_handler import PlacehodlerHandler
|
||||
from .output_handler import OutputHandler
|
||||
from .placeholder_handler import PlaceholderHandler
|
||||
from .registry import operator_registry
|
||||
from .reshape_handler import ReshapeHandler
|
||||
from .softmax_handler import SoftmaxHandler
|
||||
from .sum_handler import SumHandler
|
||||
from .tensor_constructor_handler import TensorConstructorHandler
|
||||
from .unary_elementwise_handler import UnaryElementwiseHandler
|
||||
from .where_handler import WhereHandler
|
||||
|
||||
__all__ = [
|
||||
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
|
||||
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
||||
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
|
||||
'NormPoolingHandler', 'operator_registry'
|
||||
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
|
||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
|
||||
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
|
||||
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
|
||||
|
||||
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
|
||||
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
|
||||
from .node_handler import NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['ADDMMFunctionHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.addmm)
|
||||
@operator_registry.register(torch.Tensor.addmm)
|
||||
class ADDMMFunctionHandler(NodeHandler):
|
||||
"""
|
||||
This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch.
|
||||
Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is
|
||||
no logical-physical shape conversion in this handler.
|
||||
"""
|
||||
|
||||
def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
|
||||
if isinstance(tensor, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
return data_type
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
|
||||
# input operand
|
||||
input_data = self.node.args[1]._meta_data
|
||||
physical_input_operand = OperationData(name=str(self.node.args[1]),
|
||||
type=self._infer_op_data_type(input_data),
|
||||
data=input_data)
|
||||
|
||||
# other operand
|
||||
other_data = self.node.args[2]._meta_data
|
||||
physical_other_operand = OperationData(name=str(self.node.args[2]),
|
||||
type=self._infer_op_data_type(other_data),
|
||||
data=other_data)
|
||||
# bias physical shape
|
||||
bias_logical_shape = self.node._meta_data.shape
|
||||
bias_data = self.node.args[0]._meta_data
|
||||
physical_bias_operand = OperationData(name=str(self.node.args[0]),
|
||||
type=self._infer_op_data_type(bias_data),
|
||||
data=bias_data,
|
||||
logical_shape=bias_logical_shape)
|
||||
|
||||
# output
|
||||
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"other": physical_other_operand,
|
||||
"output": physical_output,
|
||||
'bias': physical_bias_operand
|
||||
}
|
||||
|
||||
return mapping
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(
|
||||
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
|
||||
return generators
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
|
||||
# convert bias from its logical sharding spec to its physical sharding spec
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
|
||||
bias_op_data = op_data_mapping['bias']
|
||||
bias_physical_shape = bias_op_data.data.shape
|
||||
bias_logical_shape = bias_op_data.logical_shape
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
|
||||
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
|
||||
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
|
||||
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
|
||||
|
||||
if len(removed_dims) > 0:
|
||||
comm_action = comm_actions_for_oprands(node=self.node,
|
||||
removed_dims=removed_dims,
|
||||
op_data=bias_op_data,
|
||||
sharding_spec=bias_sharding_spec)
|
||||
strategy.communication_actions[bias_op_data] = comm_action
|
||||
|
||||
return strategy
|
|
@ -2,8 +2,10 @@ from typing import Dict, List
|
|||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import ModuleHandler
|
||||
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from .node_handler import MetaInfoModuleHandler, ModuleHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']
|
|||
@operator_registry.register(torch.nn.BatchNorm1d)
|
||||
@operator_registry.register(torch.nn.BatchNorm2d)
|
||||
@operator_registry.register(torch.nn.BatchNorm3d)
|
||||
class BatchNormModuleHandler(ModuleHandler):
|
||||
class BatchNormModuleHandler(MetaInfoModuleHandler):
|
||||
"""
|
||||
A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
|
||||
|
||||
from ..constants import BCAST_FUNC_OP
|
||||
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
|
||||
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['BinaryElementwiseHandler']
|
||||
|
||||
|
||||
@operator_registry.register(BCAST_FUNC_OP)
|
||||
class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
An BinaryBcastOpHandler is a node handler which deals with operations which have two
|
||||
operands and broadcasting occurs such as torch.add.
|
||||
"""
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
bcast_shape = self.node._meta_data.shape
|
||||
|
||||
def _get_op_data_type(tensor):
|
||||
if isinstance(tensor, torch.nn.parameter.Parameter):
|
||||
return OperationDataType.PARAM
|
||||
else:
|
||||
return OperationDataType.ARG
|
||||
|
||||
def _get_arg_value(idx):
|
||||
if isinstance(self.node.args[idx], Node):
|
||||
meta_data = self.node.args[idx]._meta_data
|
||||
else:
|
||||
# this is in fact a real data like int 1
|
||||
# but we can deem it as meta data
|
||||
# as it won't affect the strategy generation
|
||||
assert isinstance(self.node.args[idx], (int, float))
|
||||
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
|
||||
return meta_data
|
||||
|
||||
input_meta_data = _get_arg_value(0)
|
||||
other_meta_data = _get_arg_value(1)
|
||||
output_meta_data = self.node._meta_data
|
||||
|
||||
input_op_data = OperationData(name=str(self.node.args[0]),
|
||||
type=_get_op_data_type(input_meta_data),
|
||||
data=input_meta_data,
|
||||
logical_shape=bcast_shape)
|
||||
other_op_data = OperationData(name=str(self.node.args[1]),
|
||||
type=_get_op_data_type(other_meta_data),
|
||||
data=other_meta_data,
|
||||
logical_shape=bcast_shape)
|
||||
output_op_data = OperationData(name=str(self.node),
|
||||
type=OperationDataType.OUTPUT,
|
||||
data=output_meta_data,
|
||||
logical_shape=bcast_shape)
|
||||
|
||||
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
|
||||
return mapping
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||
return generators
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
|
||||
# convert bias from its logical sharding spec to its physical sharding spec
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
|
||||
for op_name, op_data in op_data_mapping.items():
|
||||
if not isinstance(op_data.data, torch.Tensor):
|
||||
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
|
||||
strategy.sharding_specs.pop(op_data)
|
||||
else:
|
||||
# convert the logical sharding spec to physical sharding spec if broadcast
|
||||
# e.g. torch.rand(4, 4) + torch.rand(4)
|
||||
physical_shape = op_data.data.shape
|
||||
logical_shape = op_data.logical_shape
|
||||
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
|
||||
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
|
||||
sharding_spec, logical_shape, physical_shape)
|
||||
|
||||
strategy.sharding_specs[op_data] = sharding_spec
|
||||
if len(removed_dims) > 0:
|
||||
comm_action = comm_actions_for_oprands(node=self.node,
|
||||
removed_dims=removed_dims,
|
||||
op_data=op_data,
|
||||
sharding_spec=sharding_spec)
|
||||
strategy.communication_actions[op_data] = comm_action
|
||||
|
||||
return strategy
|
|
@ -2,8 +2,10 @@ from typing import Dict, List, Union
|
|||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from ..utils import recover_sharding_spec_for_broadcast_shape
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
|
||||
|
||||
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
|
||||
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
|
||||
from .node_handler import NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
|
||||
|
@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
|
|||
bias_physical_shape = bias_op_data.data.shape
|
||||
bias_logical_shape = bias_op_data.logical_shape
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
|
||||
bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
|
||||
bias_physical_shape)
|
||||
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
|
||||
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
|
||||
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
|
||||
|
||||
if len(removed_dims) > 0:
|
||||
comm_action = comm_actions_for_oprands(node=self.node,
|
||||
removed_dims=removed_dims,
|
||||
op_data=bias_op_data,
|
||||
sharding_spec=bias_sharding_spec)
|
||||
strategy.communication_actions[bias_op_data] = comm_action
|
||||
|
||||
return strategy
|
||||
|
|
|
@ -3,9 +3,9 @@ from typing import Dict, List
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
|
||||
from ..utils import transpose_partition_dim
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import ConvStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
|
|||
@operator_registry.register(torch.nn.Conv1d)
|
||||
@operator_registry.register(torch.nn.Conv2d)
|
||||
@operator_registry.register(torch.nn.Conv3d)
|
||||
class ConvModuleHandler(ModuleHandler):
|
||||
class ConvModuleHandler(MetaInfoModuleHandler):
|
||||
"""
|
||||
A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
|
||||
"""
|
||||
|
@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):
|
|||
@operator_registry.register(F.conv1d)
|
||||
@operator_registry.register(F.conv2d)
|
||||
@operator_registry.register(F.conv3d)
|
||||
class ConvFunctionHandler(NodeHandler):
|
||||
class ConvFunctionHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,230 @@
|
|||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
|
||||
|
||||
|
||||
def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
|
||||
output_name: str) -> List[ShardingStrategy]:
|
||||
"""
|
||||
This function converts the logical sharding spec to the physical sharding spec for both the input and output
|
||||
of the embedding operation.
|
||||
|
||||
Args:
|
||||
strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
|
||||
input_name (str): the name of the OperationData object for the input.
|
||||
output_name (str): the name of the OperationData object for the output.
|
||||
"""
|
||||
# the result will be a list of strategies
|
||||
sharding_strategies = []
|
||||
|
||||
# get operation data
|
||||
input_op_data = strategy.get_op_data_by_name(input_name)
|
||||
output_op_data = strategy.get_op_data_by_name(output_name)
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
|
||||
|
||||
# recover the last logical dimension to physical dimension
|
||||
last_logical_output_dims = len(output_op_data.logical_shape) - 1
|
||||
last_physical_output_dims = output_op_data.data.dim() - 1
|
||||
|
||||
# get logger for debug message
|
||||
logger = get_dist_logger()
|
||||
|
||||
# For the input of the embedding operation, it can be multi-dimensional. The sharding spec is only generated for
|
||||
# logical 1D non-matrix dimension, the logical non-matrix dimension can belong to the 0th to Nth dimension of the
|
||||
# physical input shape. Thus, we enumerate to get all possible cases.
|
||||
if input_sharding_spec.dim_partition_dict:
|
||||
# if bool(input_sharding_spec.dim_partition_dict), it means that the
|
||||
# the generated sharding strategy does shard the non-matrix dimension,
|
||||
# in this case, we need to do enumeration
|
||||
num_input_dims = input_op_data.data.dim()
|
||||
for i in range(num_input_dims):
|
||||
strategy_copy = strategy.clone()
|
||||
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
|
||||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
try:
|
||||
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
|
||||
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
|
||||
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
|
||||
else:
|
||||
dim_mapping = {0: i}
|
||||
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping=dim_mapping,
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
|
||||
strategy_copy.name = f'{strategy.name}_{i}'
|
||||
sharding_strategies.append(strategy_copy)
|
||||
|
||||
except ShardingNotDivisibleError as e:
|
||||
logger.debug(
|
||||
f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
|
||||
)
|
||||
else:
|
||||
# the generated sharding strategy does not shard the non-matrix dimension,
|
||||
# in this case, we don't need to do enumeration
|
||||
# but instead, we still need to convert the logical shape to physical shape
|
||||
strategy_copy = strategy.clone()
|
||||
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
|
||||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
|
||||
# after updating, the logical shape will be replaced by the physical shape
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
|
||||
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
|
||||
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
|
||||
else:
|
||||
dim_mapping = {}
|
||||
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping=dim_mapping,
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
sharding_strategies.append(strategy_copy)
|
||||
|
||||
return sharding_strategies
|
||||
|
||||
|
||||
@operator_registry.register(torch.nn.Embedding)
|
||||
class EmbeddingModuleHandler(ModuleHandler):
|
||||
"""
|
||||
A EmbeddingModuleHandler which deals with the sharding strategies for nn.Embedding module.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# In nn.Embedding operation, all the dimensions of input will be treated as the batch dimension,
|
||||
# and then the sharding spec will be generated based on the logical 1D tensor.
|
||||
# After that, the logical sharding info will be enumerated among all the physical dimensions.
|
||||
# Finally, the input will be transformed back to its original shape in self.post_process
|
||||
input_meta_data = self.node.args[0]._meta_data
|
||||
input_logical_shape = input_meta_data.view(-1).shape
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]),
|
||||
type=OperationDataType.ARG,
|
||||
data=input_meta_data,
|
||||
logical_shape=input_logical_shape)
|
||||
|
||||
physical_other_operand = OperationData(name="weight",
|
||||
type=OperationDataType.PARAM,
|
||||
data=self.named_parameters['weight'])
|
||||
|
||||
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
|
||||
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
|
||||
# on the logical 2D tensor.
|
||||
# After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
|
||||
# Finally, the output will be transformed back to its original shape in self.post_process
|
||||
output_meta_data = self.node._meta_data
|
||||
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
|
||||
physical_output = OperationData(name=str(self.node),
|
||||
type=OperationDataType.OUTPUT,
|
||||
data=output_meta_data,
|
||||
logical_shape=output_logical_shape)
|
||||
|
||||
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
||||
|
||||
return mapping
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
|
||||
"""
|
||||
Convert the sharding spec from the logical shape to the physical shape.
|
||||
"""
|
||||
# create multiple sharding strategies for the inputs
|
||||
# as input can be multi-dimensinal and the partition dim is only 2D,
|
||||
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
|
||||
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
|
||||
input_name=str(
|
||||
self.node.args[0]),
|
||||
output_name=str(self.node))
|
||||
return strategies
|
||||
|
||||
|
||||
@operator_registry.register(F.embedding)
|
||||
class EmbeddingFunctionHandler(NodeHandler):
|
||||
"""
|
||||
A EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# In F.embedding operation, all the dimensions of input will be treated as the batch dimension,
|
||||
# and then the sharding spec will be generated based on the logical 1D tensor.
|
||||
# After that, the logical sharding info will be enumerated among all the physical dimensions.
|
||||
# Finally, the input will be transformed back to its original shape in self.post_process
|
||||
input_meta_data = self.node.args[0]._meta_data
|
||||
input_logical_shape = input_meta_data.view(-1).shape
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]),
|
||||
type=OperationDataType.ARG,
|
||||
data=self.node.args[0]._meta_data,
|
||||
logical_shape=input_logical_shape)
|
||||
|
||||
# check if the other operand is a parameter
|
||||
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
physical_other_operand = OperationData(name=str(self.node.args[1]),
|
||||
type=data_type,
|
||||
data=self.node.args[1]._meta_data)
|
||||
|
||||
# Same as input, in F.embedding operation, all the dimensions of output will be treated as
|
||||
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
|
||||
# on the logical 2D tensor.
|
||||
# After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
|
||||
# Finally, the output will be transformed back to its original shape in self.post_process
|
||||
output_meta_data = self.node._meta_data
|
||||
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
|
||||
physical_output = OperationData(
|
||||
name=str(self.node),
|
||||
type=OperationDataType.OUTPUT,
|
||||
data=self.node._meta_data,
|
||||
logical_shape=output_logical_shape,
|
||||
)
|
||||
|
||||
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
||||
|
||||
return mapping
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy):
|
||||
"""
|
||||
Convert the sharding spec from the logical shape to the physical shape.
|
||||
"""
|
||||
# create multiple sharding strategies for the inputs
|
||||
# as input can be multi-dimensinal and the partition dim is only 2D,
|
||||
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
|
||||
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
|
||||
input_name=str(
|
||||
self.node.args[0]),
|
||||
output_name=str(self.node))
|
||||
return strategies
|
|
@ -0,0 +1,10 @@
|
|||
from .permute_handler import PermuteHandler
|
||||
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
|
||||
from .split_handler import SplitHandler
|
||||
from .transpose_handler import TransposeHandler
|
||||
from .view_handler import ViewHandler
|
||||
|
||||
__all__ = [
|
||||
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator',
|
||||
'SplitHandler', 'SplitGenerator'
|
||||
]
|
|
@ -0,0 +1,76 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ...sharding_strategy import OperationData, OperationDataType
|
||||
from ..node_handler import NodeHandler
|
||||
from ..registry import operator_registry
|
||||
from ..strategy import StrategyGenerator
|
||||
from .reshape_generator import PermuteGenerator
|
||||
|
||||
__all__ = ['PermuteHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.permute)
|
||||
@operator_registry.register(torch.permute)
|
||||
class PermuteHandler(NodeHandler):
|
||||
"""
|
||||
A PermuteHandler which deals with the sharding strategies for torch.permute or torch.transpose.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# check if the input operand is a parameter
|
||||
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
input_data = self.node.args[0]._meta_data
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||
|
||||
permute_dims = []
|
||||
if self.node.op == 'call_method':
|
||||
# torch.Tensor.permute (input, *dims)
|
||||
for arg in self.node.args:
|
||||
if isinstance(arg, torch.fx.Node):
|
||||
if isinstance(arg._meta_data, int):
|
||||
permute_dims.append(arg._meta_data)
|
||||
else:
|
||||
assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
|
||||
permute_dims.append(arg)
|
||||
else:
|
||||
# torch.permute (input, dims)
|
||||
for arg in self.node.args:
|
||||
if isinstance(arg, torch.fx.Node):
|
||||
if isinstance(arg._meta_data, (tuple, list)):
|
||||
permute_dims.extend(arg._meta_data)
|
||||
else:
|
||||
assert isinstance(
|
||||
arg,
|
||||
(tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
|
||||
permute_dims.extend(arg)
|
||||
|
||||
num_dims = self.node._meta_data.dim()
|
||||
for i in range(num_dims):
|
||||
# recover negative value to positive
|
||||
if permute_dims[i] < 0:
|
||||
permute_dims[i] += num_dims
|
||||
|
||||
physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
|
||||
|
||||
output_data = self.node._meta_data
|
||||
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"permute_dims": physical_shape_operand,
|
||||
"output": physical_output_operand
|
||||
}
|
||||
|
||||
return mapping
|
|
@ -0,0 +1,299 @@
|
|||
import copy
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
MemoryCost,
|
||||
ShardingStrategy,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
check_keep_sharding_status,
|
||||
detect_reshape_mapping,
|
||||
infer_output_dim_partition_dict,
|
||||
)
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
|
||||
|
||||
|
||||
class ReshapeGenerator(FollowingStrategyGenerator):
|
||||
"""
|
||||
ReshapeGenerator is the base class for all the reshape operation.
|
||||
"""
|
||||
|
||||
def validate(self) -> bool:
|
||||
return super().validate()
|
||||
|
||||
def update_compute_cost(self, strategy: ShardingStrategy):
|
||||
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
|
||||
strategy.compute_cost = compute_cost
|
||||
|
||||
def update_memory_cost(self, strategy: ShardingStrategy):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
'''
|
||||
forward_size_mapping = {
|
||||
'input': self._compute_size_in_bytes(strategy, "input"),
|
||||
'output': self._compute_size_in_bytes(strategy, "output")
|
||||
}
|
||||
|
||||
backward_size_mapping = copy.deepcopy(forward_size_mapping)
|
||||
backward_size_mapping.pop("output")
|
||||
# compute fwd cost incurred
|
||||
# fwd_cost = input + output
|
||||
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
|
||||
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
|
||||
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
|
||||
|
||||
# compute bwd cost incurred
|
||||
# bwd_cost = input_grad
|
||||
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
|
||||
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
|
||||
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
|
||||
|
||||
# compute total cost
|
||||
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
|
||||
parameter=fwd_parameter_cost + bwd_parameter_cost)
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
strategy.memory_cost = memory_cost
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
return super().collate_strategies()
|
||||
|
||||
|
||||
class ViewGenerator(ReshapeGenerator):
|
||||
"""
|
||||
ViewGenerator deals with the sharding strategies of view op.
|
||||
"""
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
|
||||
origin_shape = self.op_data['input'].data.shape
|
||||
tgt_shape = self.op_data['tgt_shape'].data
|
||||
|
||||
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
|
||||
|
||||
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
|
||||
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
|
||||
|
||||
if keep_sharding_status:
|
||||
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
|
||||
reshape_mapping_dict)
|
||||
else:
|
||||
dim_partition_dict_for_output = {}
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
"input": dim_partition_dict_for_input,
|
||||
"output": dim_partition_dict_for_output,
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
if keep_sharding_status:
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
|
||||
else:
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
|
||||
|
||||
# add comm action for converting input to fully replicated
|
||||
total_mesh_dim_list = []
|
||||
for mesh_dim_list in dim_partition_dict_for_input.values():
|
||||
total_mesh_dim_list.extend(mesh_dim_list)
|
||||
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
|
||||
if len(total_mesh_dim_list) == 1:
|
||||
total_mesh_dim_list = total_mesh_dim_list[0]
|
||||
# the total mesh dim list only has one element, so the shard dim has only one element as well.
|
||||
shard_dim = list(dim_partition_dict_for_input.keys())[0]
|
||||
input_comm_action = self.get_communication_action(
|
||||
sharding_spec=sharding_spec_mapping["input"],
|
||||
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
logical_process_axis=total_mesh_dim_list,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=0)
|
||||
# it will gather the input through gather_dim during forward phase.
|
||||
input_comm_action.comm_spec.gather_dim = shard_dim
|
||||
# it will split the input activation grad through shard_dim during backward phase.
|
||||
input_comm_action.comm_spec.shard_dim = shard_dim
|
||||
|
||||
elif len(total_mesh_dim_list) >= 2:
|
||||
source_spec = sharding_spec_mapping["input"]
|
||||
target_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=source_spec.entire_shape,
|
||||
dim_partition_dict={})
|
||||
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
|
||||
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
|
||||
|
||||
else:
|
||||
input_comm_action = None
|
||||
|
||||
if input_comm_action is not None:
|
||||
communication_action_mapping["input"] = input_comm_action
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategy_list.append(strategy)
|
||||
|
||||
return strategy_list
|
||||
|
||||
|
||||
class PermuteGenerator(ReshapeGenerator):
|
||||
"""
|
||||
PermuteGenerator deals with the sharding strategies of permute op.
|
||||
"""
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
|
||||
permute_dims = self.op_data['permute_dims'].data
|
||||
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
|
||||
dim_partition_dict_for_output = {}
|
||||
for dim_index, permute_dim in enumerate(permute_dims):
|
||||
if permute_dim in dim_partition_dict_for_input:
|
||||
dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
"input": dim_partition_dict_for_input,
|
||||
"output": dim_partition_dict_for_output,
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategy_list.append(strategy)
|
||||
|
||||
return strategy_list
|
||||
|
||||
|
||||
class TransposeGenerator(ReshapeGenerator):
|
||||
"""
|
||||
TransposeGenerator deals with the sharding strategies of permute op.
|
||||
"""
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
|
||||
dim_partition_dict_for_output = {}
|
||||
|
||||
transpose_dims = self.op_data['transpose_dims'].data
|
||||
dim_0 = transpose_dims[0]
|
||||
dim_1 = transpose_dims[1]
|
||||
for dim, sharded_dims in dim_partition_dict_for_input.items():
|
||||
if dim == dim_0:
|
||||
dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
|
||||
elif dim == dim_1:
|
||||
dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
|
||||
else:
|
||||
dim_partition_dict_for_output[dim] = sharded_dims
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
"input": dim_partition_dict_for_input,
|
||||
"output": dim_partition_dict_for_output,
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategy_list.append(strategy)
|
||||
|
||||
return strategy_list
|
||||
|
||||
|
||||
class SplitGenerator(ReshapeGenerator):
|
||||
"""
|
||||
SplitGenerator deals with the sharding strategies of split op.
|
||||
"""
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
recover_dims = None
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
split_size, split_dim = self.op_data['split_info'].data
|
||||
|
||||
if split_dim in dim_partition_dict_for_input:
|
||||
recover_dims = dim_partition_dict_for_input.pop(split_dim)
|
||||
|
||||
dim_partition_dict_for_output = [
|
||||
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
|
||||
]
|
||||
assert len(dim_partition_dict_for_output) >= 2
|
||||
dim_partition_dict_mapping = {
|
||||
"input": dim_partition_dict_for_input,
|
||||
"output": dim_partition_dict_for_output,
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
|
||||
|
||||
# add comm action if the input need to be recovered to replica in the split dimension.
|
||||
if recover_dims:
|
||||
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
|
||||
if len(recover_dims) == 1:
|
||||
recover_dims = recover_dims[0]
|
||||
input_comm_action = self.get_communication_action(
|
||||
sharding_spec=sharding_spec_mapping["input"],
|
||||
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
logical_process_axis=recover_dims,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=0)
|
||||
# it will gather the input through gather_dim during forward phase.
|
||||
input_comm_action.comm_spec.gather_dim = split_dim
|
||||
# it will split the input activation grad through split_dim during backward phase.
|
||||
input_comm_action.comm_spec.shard_dim = split_dim
|
||||
|
||||
elif len(recover_dims) >= 2:
|
||||
# original sharding spec
|
||||
source_spec = input_sharding_spec
|
||||
# target sharding spec
|
||||
target_spec = sharding_spec_mapping["input"]
|
||||
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
|
||||
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
|
||||
|
||||
else:
|
||||
input_comm_action = None
|
||||
|
||||
if input_comm_action is not None:
|
||||
communication_action_mapping["input"] = input_comm_action
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategy_list.append(strategy)
|
||||
|
||||
return strategy_list
|
|
@ -0,0 +1,63 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ...sharding_strategy import OperationData, OperationDataType
|
||||
from ..node_handler import NodeHandler
|
||||
from ..registry import operator_registry
|
||||
from ..strategy import StrategyGenerator
|
||||
from .reshape_generator import SplitGenerator
|
||||
|
||||
__all__ = ['SplitHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.split)
|
||||
@operator_registry.register(torch.split)
|
||||
class SplitHandler(NodeHandler):
|
||||
"""
|
||||
A SplitHandler which deals with the sharding strategies for torch.permute or torch.split.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# check if the input operand is a parameter
|
||||
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
input_data = self.node.args[0]._meta_data
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||
split_size = self.node.args[1]
|
||||
if len(self.node.args) == 3:
|
||||
# (input, split_size, split_dim)
|
||||
split_dim = self.node.args[2]
|
||||
else:
|
||||
if self.node.kwargs:
|
||||
split_dim = self.node.kwargs['dim']
|
||||
else:
|
||||
split_dim = 0
|
||||
|
||||
num_dims = self.node.args[0]._meta_data.dim()
|
||||
# recover negative value to positive
|
||||
if split_dim < 0:
|
||||
split_dim += num_dims
|
||||
|
||||
split_info = (split_size, split_dim)
|
||||
physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
|
||||
|
||||
output_data = self.node._meta_data
|
||||
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"split_info": physical_shape_operand,
|
||||
"output": physical_output_operand
|
||||
}
|
||||
|
||||
return mapping
|
|
@ -0,0 +1,65 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ...sharding_strategy import OperationData, OperationDataType
|
||||
from ..node_handler import NodeHandler
|
||||
from ..registry import operator_registry
|
||||
from ..strategy import StrategyGenerator
|
||||
from .reshape_generator import TransposeGenerator
|
||||
|
||||
__all__ = ['TransposeHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.transpose)
|
||||
@operator_registry.register(torch.transpose)
|
||||
class TransposeHandler(NodeHandler):
|
||||
"""
|
||||
A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# check if the input operand is a parameter
|
||||
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
input_data = self.node.args[0]._meta_data
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||
|
||||
transpose_dims = []
|
||||
# torch.transpose (input, dim0, dim1)
|
||||
for arg in self.node.args:
|
||||
if isinstance(arg, torch.fx.Node):
|
||||
if isinstance(arg._meta_data, int):
|
||||
transpose_dims.append(arg._meta_data)
|
||||
else:
|
||||
transpose_dims.append(arg)
|
||||
|
||||
num_dims = self.node._meta_data.dim()
|
||||
for i in range(2):
|
||||
# recover negative value to positive
|
||||
if transpose_dims[i] < 0:
|
||||
transpose_dims[i] += num_dims
|
||||
|
||||
physical_shape_operand = OperationData(name='transpose_dims',
|
||||
type=OperationDataType.ARG,
|
||||
data=list(transpose_dims))
|
||||
|
||||
output_data = self.node._meta_data
|
||||
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"transpose_dims": physical_shape_operand,
|
||||
"output": physical_output_operand
|
||||
}
|
||||
|
||||
return mapping
|
|
@ -0,0 +1,53 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ...sharding_strategy import OperationData, OperationDataType
|
||||
from ..node_handler import NodeHandler
|
||||
from ..registry import operator_registry
|
||||
from ..strategy import StrategyGenerator
|
||||
from .reshape_generator import ViewGenerator
|
||||
|
||||
__all__ = ['ViewHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.reshape)
|
||||
@operator_registry.register(torch.reshape)
|
||||
@operator_registry.register(torch.Tensor.view)
|
||||
class ViewHandler(NodeHandler):
|
||||
"""
|
||||
A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# use transposed shape for strategies
|
||||
# the strategies will be transformed back to its original shape in self.post_process
|
||||
|
||||
# check if the input operand is a parameter
|
||||
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
input_data = self.node.args[0]._meta_data
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||
|
||||
target_shape = self.node._meta_data.shape
|
||||
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
|
||||
|
||||
output_data = self.node._meta_data
|
||||
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"tgt_shape": physical_shape_operand,
|
||||
"output": physical_output_operand
|
||||
}
|
||||
|
||||
return mapping
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Dict, List
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from .strategy import GetattrGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['GetattrHandler']
|
||||
|
||||
|
||||
class GetattrHandler(NodeHandler):
|
||||
"""
|
||||
A GetattrHandler which deals with the sharding strategies for Getattr Node.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(GetattrGenerator(op_data_mapping, self.device_mesh))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# use transposed shape for strategies
|
||||
# the strategies will be transformed back to its original shape in self.post_process
|
||||
|
||||
# There are only two possible types for get_attr node:
|
||||
# 1. torch.Tensor(torch.nn.Parameters or torch.nn.Buffers)
|
||||
# 2. torch.nn.Module
|
||||
# temporarily, we just support first case in Tracer, so we don't have to worry about
|
||||
# issue related to the node._meta_data type.
|
||||
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||||
|
||||
mapping = {"output": physical_output}
|
||||
|
||||
return mapping
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
|
||||
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
|
||||
|
||||
__all__ = ['GetItemHandler']
|
||||
|
||||
|
|
|
@ -3,12 +3,16 @@ from typing import Dict, List, Union
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
check_sharding_spec_validity,
|
||||
transpose_partition_dim,
|
||||
update_partition_dim,
|
||||
)
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
|
||||
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -28,9 +32,11 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
|
|||
# switch the dimensions of the transposed weight
|
||||
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
|
||||
op_data = strategy.get_op_data_by_name(weight_name)
|
||||
assert op_data.logical_shape != op_data.data.shape, \
|
||||
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
|
||||
transpose_partition_dim(sharding_spec, 0, -1)
|
||||
assert op_data.logical_shape[0] == op_data.data.shape[1] and \
|
||||
op_data.logical_shape[1] == op_data.data.shape[0], \
|
||||
"Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
|
||||
dim_size = len(op_data.logical_shape)
|
||||
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
|
||||
return strategy
|
||||
|
||||
|
||||
|
@ -54,6 +60,23 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
    input_op_data = strategy.get_op_data_by_name(input_name)
    output_op_data = strategy.get_op_data_by_name(output_name)
    input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
    output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)

    # recover the last logical dimension to physical dimension
    last_logical_input_dims = len(input_op_data.logical_shape) - 1
    last_logical_output_dims = len(output_op_data.logical_shape) - 1
    last_physical_input_dims = input_op_data.data.dim() - 1
    last_physical_output_dims = output_op_data.data.dim() - 1

    if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
        input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
    else:
        input_last_dim_mapping = {}

    if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
        output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
    else:
        output_last_dim_mapping = {}

    # get logger for debug message
    logger = get_dist_logger()
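A small, hand-worked illustration of the last-dimension bookkeeping above; the shapes below are hypothetical and only the index arithmetic mirrors the hunk:

# An F.linear input of physical shape [4, 8, 16] is treated logically as [32, 16],
# so the last logical dim (index 1) maps back to the last physical dim (index 2).
logical_shape = [4 * 8, 16]                  # assumed logical view of a [4, 8, 16] input
physical_ndim = 3                            # rank of the real tensor
last_logical_dim = len(logical_shape) - 1    # 1
last_physical_dim = physical_ndim - 1        # 2
input_last_dim_mapping = {last_logical_dim: last_physical_dim}    # {1: 2}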
@ -73,14 +96,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
        output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
        try:
            # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
            input_dim_mapping = {0: i}
            input_dim_mapping.update(input_last_dim_mapping)

            update_partition_dim(sharding_spec=input_sharding_spec,
                                 dim_mapping={0: i},
                                 dim_mapping=input_dim_mapping,
                                 physical_shape=input_op_data.data.shape,
                                 inplace=True)
            output_dim_mapping = {0: i}
            output_dim_mapping.update(output_last_dim_mapping)

            update_partition_dim(sharding_spec=output_sharding_spec,
                                 dim_mapping={0: i},
                                 dim_mapping=output_dim_mapping,
                                 physical_shape=output_op_data.data.shape,
                                 inplace=True)
            strategy_copy.name = f'{strategy.name}_{i}'
            sharding_strategies.append(strategy_copy)
        except ShardingNotDivisibleError as e:
            logger.debug(
@ -95,12 +125,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
        output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)

        # after updating, the logical shape will be replaced by the physical shape
        input_dim_mapping = {}
        input_dim_mapping.update(input_last_dim_mapping)
        update_partition_dim(sharding_spec=input_sharding_spec,
                             dim_mapping={},
                             dim_mapping=input_dim_mapping,
                             physical_shape=input_op_data.data.shape,
                             inplace=True)

        output_dim_mapping = {}
        output_dim_mapping.update(output_last_dim_mapping)
        update_partition_dim(sharding_spec=output_sharding_spec,
                             dim_mapping={},
                             dim_mapping=output_dim_mapping,
                             physical_shape=output_op_data.data.shape,
                             inplace=True)
        sharding_strategies.append(strategy_copy)
@ -108,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha


@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
class LinearModuleHandler(MetaInfoModuleHandler):
    """
    A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
    """
@ -116,7 +151,8 @@ class LinearModuleHandler(ModuleHandler):
    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
        generators.append(
            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@ -167,15 +203,16 @@ class LinearModuleHandler(ModuleHandler):


@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
class LinearFunctionHandler(MetaInfoNodeHandler):
    """
    A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
    A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
        generators.append(
            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@ -198,27 +235,34 @@ class LinearFunctionHandler(NodeHandler):
                                               type=data_type,
                                               data=self.node.args[1]._meta_data,
                                               logical_shape=self.node.args[1]._meta_data.shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
        output_meta_data = self.node._meta_data
        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
        physical_output = OperationData(
            name=str(self.node),
            type=OperationDataType.OUTPUT,
            data=self.node._meta_data,
            logical_shape=output_logical_shape,
        )

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.node.args[2] is not None:
        if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
            # check if the other operand is a parameter
            if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
            if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                data_type = OperationDataType.PARAM
            else:
                data_type = OperationDataType.ARG
            physical_bias_operand = OperationData(name=str(self.node.args[2]),
            physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
                                                  type=data_type,
                                                  data=self.node.args[2]._meta_data)
                                                  data=self.node.kwargs["bias"]._meta_data)
            mapping['bias'] = physical_bias_operand

        return mapping

    def post_process(self, strategy: ShardingStrategy):
        # switch the dimensions of the transposed weight
        strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                          weight_name=str(self.node.args[1]))

        # create multiple sharding strategies for the inputs
        # as the input can be multi-dimensional and the partition dim is only 2D,
        # we need to map the partition at dim 0 to one of the first few dimensions of the input
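The new output logical shape above folds every leading dimension of the F.linear output into one. A minimal sketch with a made-up meta tensor shows the effect:

import torch

output_meta_data = torch.empty(4, 8, 16, device='meta')    # hypothetical _meta_data of an F.linear node
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
print(output_logical_shape)    # torch.Size([32, 16]), i.e. the output is treated as a 2D matrix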
@ -0,0 +1,486 @@
import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union

import torch

from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
    BroadcastType,
    get_broadcast_dim_info,
    get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException

from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
    BatchedMatMulStrategyGenerator,
    DotProductStrategyGenerator,
    LinearProjectionStrategyGenerator,
    MatVecStrategyGenerator,
    StrategyGenerator,
)
class MatMulType(Enum):
    """
    The MatMulType is categorized into 4 types based on the reference of torch.matmul
    in https://pytorch.org/docs/stable/generated/torch.matmul.html.

    DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements
    MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D
    MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
    BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
    """
    DOT = 0
    MM = 1
    MV = 2
    BMM = 3


def get_matmul_type(input_dim: int, other_dim: int):
    """
    Determine which type of matmul operation should be executed for the given tensor dimensions.

    Args:
        input_dim (int): the number of dimensions for the input tensor
        other_dim (int): the number of dimensions for the other tensor
    """
    if input_dim == 1 and other_dim == 1:
        matmul_type = MatMulType.DOT
    elif input_dim in [1, 2] and other_dim == 2:
        matmul_type = MatMulType.MM
    elif input_dim == 2 and other_dim == 1:
        matmul_type = MatMulType.MV
    elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
        matmul_type = MatMulType.BMM
    else:
        raise ValueError(
            f"The input and other tensors have {input_dim} and {other_dim} dimensions, which cannot be used to execute a matmul operation"
        )
    return matmul_type
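As a quick sanity check on the branches above, the classification behaves as follows (illustrative only, assuming get_matmul_type and MatMulType are in scope):

assert get_matmul_type(1, 1) == MatMulType.DOT    # dot product of two 1D tensors
assert get_matmul_type(1, 2) == MatMulType.MM     # a 1D x 2D product is treated as matrix-matrix
assert get_matmul_type(2, 2) == MatMulType.MM     # plain matrix-matrix product
assert get_matmul_type(2, 1) == MatMulType.MV     # matrix-vector product
assert get_matmul_type(4, 3) == MatMulType.BMM    # batched matmul once either operand has more than 2 dims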
class BmmTransform(ABC):
    """
    BmmTransform is an abstraction of the shape conversion between logical and physical operation data
    during the strategy generation.
    """

    @abstractmethod
    def apply(self, shape_mapping: Dict[str, List[int]]):
        pass

    @abstractmethod
    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        pass
class Padder(BmmTransform):
    """
    Add padding to the matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        # keep the padding dim, op_name -> padded_dim
        self.padded_dim_mapping = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = deepcopy(shape_mapping)
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        if len(input_shape) == 1:
            # if the input is a 1D tensor, 1 is prepended to its shape
            # and it will be removed afterwards
            input_shape.insert(0, 1)
            self.padded_dim_mapping['input'] = -2
            self.padded_dim_mapping['output'] = -2
        elif len(other_shape) == 1:
            # if the other is a 1D tensor, 1 is appended to its shape
            # and it will be removed afterwards
            other_shape.append(1)
            self.padded_dim_mapping['other'] = -1
            self.padded_dim_mapping['output'] = -1
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        input_op_data = op_data_mapping['input']
        other_op_data = op_data_mapping['other']

        def _remove_padded_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)
            dim_partition_list = [None] * len(tensor_shape)

            # padded dim is a negative number as the padded dim must be a matrix dim
            padded_dim = self.padded_dim_mapping[key]

            # compute the new dim partition
            for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
                dim_partition_list[tensor_dim] = mesh_dims
            dim_partition_list.pop(padded_dim)
            unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}

            # compute unpadded tensor shape
            tensor_shape.pop(padded_dim)

            assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'

            # update sharding spec
            sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()

            # only one of input and other will be padded
            if 'input' in self.padded_dim_mapping:
                _remove_padded_dim('input', strategy_copy)
                _remove_padded_dim('output', strategy_copy)
            elif 'other' in self.padded_dim_mapping:
                _remove_padded_dim('other', strategy_copy)
                _remove_padded_dim('output', strategy_copy)

            strategies.append(strategy_copy)
        except ShardingSpecException as e:
            pass
        return strategies
class Broadcaster(BmmTransform):
    """
    Broadcast the non-matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        self.broadcast_dim_info = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()

        # get shapes
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        # sanity check
        assert len(input_shape) > 1 and len(other_shape) > 1

        # broadcast the batch dim and record
        bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])

        # store the broadcast dim info
        input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
        other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
        self.broadcast_dim_info['input'] = input_broadcast_dim_info
        self.broadcast_dim_info['other'] = other_broadcast_dim_info

        # create the full logical shape
        input_shape = bcast_non_matrix_dims + input_shape[-2:]
        other_shape = bcast_non_matrix_dims + other_shape[-2:]
        assert len(input_shape) == len(other_shape)

        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape

        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # remove sharding on the broadcast dim
        def _remove_sharding_on_broadcast_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)

            for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
                if broadcast_type == BroadcastType.MULTIPLE:
                    # if the dim is originally 1 and multiplied during broadcast
                    # we set its sharding to R
                    # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
                    # the dim 0 of [1, 2, 4] is multiplied to 4
                    tensor_shape[dim_idx] = 1
                elif broadcast_type == BroadcastType.PADDDING:
                    # if the dim is padded
                    # we remove its sharding
                    tensor_shape[dim_idx] = None

            tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]

            physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
                logical_sharding_spec=sharding_spec,
                logical_shape=sharding_spec.entire_shape,
                physical_shape=tensor_shape_before_broadcast)
            strategy.sharding_specs[op_data] = physical_sharding_spec

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()
            _remove_sharding_on_broadcast_dim('input', strategy_copy)
            _remove_sharding_on_broadcast_dim('other', strategy_copy)
            strategies.append(strategy_copy)
        except ShardingSpecException as e:
            pass
        return strategies
class Viewer(BmmTransform):
    """
    Change the shape of the tensor from N-D to 3D
    """

    def __init__(self) -> None:
        self.batch_dims_before_view = None

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()
        self.batch_dims_before_view = list(mapping_copy['input'][:-2])

        # get shapes
        input_shape = shape_mapping['input']
        other_shape = shape_mapping['other']

        # view to 3d tensor
        assert len(input_shape) >= 3 and len(other_shape) >= 3
        input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
        other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
        output_shape = input_shape[:2] + other_shape[2:]
        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape
        mapping_copy['output'] = output_shape
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # get operation data
        def _update_sharding_spec(key, strategy, physical_batch_dim):
            """
            Map the logical batch dim to the physical batch dim
            """
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            dim_partition_dict = sharding_spec.dim_partition_dict
            entire_shape = sharding_spec.entire_shape

            # update the dimension index for the matrix dimensions
            if 2 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
            if 1 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)

            # map the logical batch dim to the physical batch dim
            if 0 in dim_partition_dict:
                batch_dim_shard = dim_partition_dict.pop(0)
                dim_partition_dict[physical_batch_dim] = batch_dim_shard

            # the new shape will be the batch dims + the last 2 matrix dims
            shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
            sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)

        num_batch_dim_before_view = len(self.batch_dims_before_view)

        # enumerate all sharding strategies
        strategies = []
        for i in range(num_batch_dim_before_view):
            # create a new strategy
            strategy_copy = strategy.clone()
            try:
                _update_sharding_spec('input', strategy_copy, i)
                _update_sharding_spec('other', strategy_copy, i)
                _update_sharding_spec('output', strategy_copy, i)
                strategies.append(strategy_copy)
            except ShardingSpecException as e:
                continue
        return strategies
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
    """
    Compute the logical shapes for the BMM operation. BMM has the general representation
    [b, i, k] = [b, i, j] x [b, j, k]

    The dimension b is called the non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions.
    The logical shape for the bmm operands will undergo three stages:
        1. append/prepend a 1 to the 1D tensor if there is any
        2. broadcast the non-matrix dimensions
        3. reshape to 3 dimensions
    """
    shape_mapping = {'input': input_shape, 'other': other_shape}

    for transform in transforms:
        shape_mapping = transform.apply(shape_mapping)

    input_shape = shape_mapping.get('input', None)
    other_shape = shape_mapping.get('other', None)
    output_shape = shape_mapping.get('output', None)

    return input_shape, other_shape, output_shape
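A hand-worked trace of the three stages for hypothetical operand shapes (the numbers are made up for illustration and simply follow the Padder, Broadcaster and Viewer logic above):

input_shape = [16]            # hypothetical 1D "input" operand
other_shape = [8, 16, 32]     # hypothetical 3D "other" operand

# 1. Padder prepends a 1 to the 1D operand:   input -> [1, 16]
# 2. Broadcaster broadcasts the batch dims:   input -> [8, 1, 16], other -> [8, 16, 32]
# 3. Viewer folds the batch dims into one:    both operands stay 3D here, and the
#    logical output becomes input[:2] + other[2:] = [8, 1, 32]
expected_logical_shapes = ([8, 1, 16], [8, 16, 32], [8, 1, 32])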
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
    """
    The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
    According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
    the operands.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # check which type of operation this matmul will call
        self.input_meta_data = self.node.args[0]._meta_data
        self.other_meta_data = self.node.args[1]._meta_data
        self.output_meta_data = self.node._meta_data

        input_dim = self.input_meta_data.dim()
        other_dim = self.other_meta_data.dim()
        self.matmul_type = get_matmul_type(input_dim, other_dim)

        if self.matmul_type == MatMulType.BMM:
            # bmm operation can possibly involve padding, broadcasting and view
            # these transforms will be used to create logical shape and
            # recover physical sharding spec
            self.transforms = [Padder(), Broadcaster(), Viewer()]
        else:
            self.transforms = None
    def get_strategy_generator(self) -> List[StrategyGenerator]:
        generators = []
        op_data_mapping = self.get_operation_data_mapping()
        if self.matmul_type == MatMulType.BMM:
            generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.DOT:
            generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MV:
            generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MM:
            generators.append(
                LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        logical_shape_func = {
            MatMulType.DOT: self._get_logical_shape_for_dot,
            MatMulType.MM: self._get_logical_shape_for_mm,
            MatMulType.MV: self._get_logical_shape_for_mv,
            MatMulType.BMM: self._get_logical_shape_for_bmm
        }
        logical_shapes = logical_shape_func[self.matmul_type]()
        op_data_mapping = self._get_op_data_mapping(*logical_shapes)
        return op_data_mapping
    def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
        # convert list to torch.Size
        if input_logical_shape:
            input_logical_shape = torch.Size(input_logical_shape)

        if other_logical_shape:
            other_logical_shape = torch.Size(other_logical_shape)

        if output_logical_shape:
            output_logical_shape = torch.Size(output_logical_shape)

        # create op data
        input_op_data = OperationData(name=str(self.node.args[0]),
                                      type=OperationDataType.ARG,
                                      data=self.input_meta_data,
                                      logical_shape=input_logical_shape)
        other_op_data = OperationData(name=str(self.node.args[1]),
                                      type=OperationDataType.ARG,
                                      data=self.other_meta_data,
                                      logical_shape=other_logical_shape)
        output_op_data = OperationData(name=str(self.node),
                                       type=OperationDataType.OUTPUT,
                                       data=self.output_meta_data,
                                       logical_shape=output_logical_shape)

        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
        return mapping

    def _get_logical_shape_for_dot(self):
        """
        The operands for the dot operation have the same logical shape as the physical shape
        """
        return None, None, None

    def _get_logical_shape_for_mm(self):
        """
        We need to handle the input tensor for a matrix-matrix multiplication as the input
        tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
        (e.g. [4] -> [1, 4]).
        """
        if self.input_meta_data.dim() == 1:
            input_logical_shape = [1] + list(self.input_meta_data.shape)
            input_logical_shape = torch.Size(input_logical_shape)
        else:
            input_logical_shape = None
        return input_logical_shape, None, None

    def _get_logical_shape_for_mv(self):
        """
        No broadcasting or dim insertion occurs for matrix-vector operation.
        """
        return None, None, None

    def _get_logical_shape_for_bmm(self):
        input_physical_shape = list(self.input_meta_data.shape)
        other_physical_shape = list(self.other_meta_data.shape)
        return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)
    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
            return strategy
        elif self.matmul_type == MatMulType.MM:
            if self.input_meta_data.dim() == 1:
                # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
                # we need to remove that dim
                input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
                input_physical_shape = self.node.args[0]._meta_data.shape
                dim_partition_dict = input_sharding_spec.dim_partition_dict

                # remove the partitioning in the dim 0
                if 0 in dim_partition_dict:
                    dim_partition_dict.pop(0, None)

                # move the partitioning in dim 1 to dim 0
                if -1 in dim_partition_dict:
                    shard = dim_partition_dict.pop(-1)
                    dim_partition_dict[0] = shard
                if 1 in dim_partition_dict:
                    shard = dim_partition_dict.pop(1)
                    dim_partition_dict[0] = shard

                # re-init the sharding spec
                input_sharding_spec.__init__(input_sharding_spec.device_mesh,
                                             entire_shape=input_physical_shape,
                                             dim_partition_dict=dim_partition_dict)
                return strategy
            else:
                return strategy
        elif self.matmul_type == MatMulType.BMM:
            op_data_mapping = self.get_operation_data_mapping()

            strategies = [strategy]
            # recover the physical sharding spec
            for transform in self.transforms[::-1]:
                recovered_stragies = []
                for strategy_ in strategies:
                    output = transform.recover(op_data_mapping, strategy_)
                    if isinstance(output, ShardingStrategy):
                        recovered_stragies.append(output)
                    elif isinstance(output, (list, tuple)):
                        recovered_stragies.extend(output)
                    else:
                        raise TypeError(
                            f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
                strategies = recovered_stragies
            return strategies
@ -1,11 +1,14 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union

import torch
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
    OperationDataType,
    ShardingSpec,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
@ -49,7 +52,16 @@ class NodeHandler(ABC):

        for node in self.predecessor_node:
            node_name = str(node)
            # get the current sharding spec generated by this node handler

            # we will not compute the resharding costs for a node that is not counted in the strategy.
            # A node with a tuple or list output needs to be handled below.
            node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
            if str(node) not in node_in_strategy:
                continue

            op_data = strategy.get_op_data_by_name(node_name)
            current_sharding_spec = strategy.sharding_specs[op_data]
            # get the sharding specs for this node generated
            # in its own node handler
            assert hasattr(node, 'strategies_vector'), \
@ -59,27 +71,83 @@ class NodeHandler(ABC):
                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
            ]

            # get the current sharding spec generated by this node handler
            op_data = strategy.get_op_data_by_name(node_name)
            current_sharding_spec = strategy.sharding_specs[op_data]

            # create a data structure to store costs
            if op_data not in resharding_costs:
            if node not in resharding_costs:
                resharding_costs[node] = []

            def _compute_resharding_cost(
                    prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
                    current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
                    data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
                """
                This is a helper function to compute the resharding cost for a specific strategy of a node.
                """
                if prev_sharding_spec is None:
                    return TrainCycleItem(fwd=0, bwd=0, total=0)
                elif isinstance(prev_sharding_spec, ShardingSpec):
                    if isinstance(data, torch.Tensor):
                        dtype = data.dtype
                        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                        _, _, consistency_cost = shape_consistency_manager.shape_consistency(
                            prev_sharding_spec, current_sharding_spec)

                        resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
                                                         bwd=consistency_cost["backward"] * size_per_elem_bytes,
                                                         total=consistency_cost["total"] * size_per_elem_bytes)
                        return resharding_cost
                    else:
                        # This raise is used to check if we have missed any type of data.
                        # It could be merged into the Parameter branch, which means we won't handle
                        # non-tensor arguments.
                        raise ValueError(f'Unsupported data type {type(data)}')
                else:
                    assert isinstance(prev_sharding_spec, (tuple, list)), \
                        f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
                        or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'

                    fwd_cost = 0
                    bwd_cost = 0
                    total_cost = 0
                    for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate(
                            zip(prev_sharding_spec, current_sharding_spec)):
                        item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
                                                             data[index])
                        fwd_cost += item_cost.fwd
                        bwd_cost += item_cost.bwd
                        total_cost += item_cost.total
                    resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
                    return resharding_cost

            # for each sharding spec generated by the predecessor's node handler
            # compute the resharding cost to switch to the sharding spec generated
            # by the current node handler
            for prev_sharding_spec in prev_sharding_specs:
                _, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec,
                                                                                    current_sharding_spec)
                resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
                                                 bwd=resharding_cost["backward"],
                                                 total=resharding_cost["total"])
                resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
                resharding_costs[node].append(resharding_cost)
        strategy.resharding_costs = resharding_costs
        return strategy
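The nested helper above recurses when a predecessor produces a tuple or list of sharding specs and sums the per-item costs. A stripped-down, self-contained analogue of that control flow (the real helper queries shape_consistency_manager and returns a TrainCycleItem; here the per-item cost is faked with constants purely to show the recursion):

from typing import Tuple

def fake_item_cost(prev_spec, cur_spec) -> Tuple[float, float, float]:
    return 1.0, 1.0, 2.0    # (fwd, bwd, total), stand-in for the real consistency cost

def compute_cost(prev, cur) -> Tuple[float, float, float]:
    if prev is None:
        return 0.0, 0.0, 0.0
    if not isinstance(prev, (tuple, list)):
        return fake_item_cost(prev, cur)
    # tuple/list case: accumulate the costs of the individual items
    fwd = bwd = total = 0.0
    for prev_item, cur_item in zip(prev, cur):
        f, b, t = compute_cost(prev_item, cur_item)
        fwd, bwd, total = fwd + f, bwd + b, total + t
    return fwd, bwd, total

print(compute_cost(("spec_a", "spec_b"), ("spec_c", "spec_d")))    # (2.0, 2.0, 4.0)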
    def get_target_function(self) -> callable:
        """
        This function is used to get the target function for the node handler.
        The target function is used to analyze the costs of strategies.
        """
        if self.node.op in ('placeholder', 'get_attr', 'output'):
            return None

        if self.node.op == 'call_module':
            target = self.node.graph.owning_module.get_submodule(self.node.target)
        elif self.node.op == 'call_function':
            target = self.node.target
        elif self.node.op == 'call_method':
            target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
        else:
            raise ValueError(f'Unsupported node type: {self.node.op}')

        return target

    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        Register different sharding strategies for the current node.
@ -151,6 +219,38 @@ class NodeHandler(ABC):
        pass


class MetaInfoNodeHandler(NodeHandler):
    """
    This is a base class to handle the nodes patched in the meta profiler.

    Note: this class will be integrated into the NodeHandler class in the future, after
    all the functions are patched.
    """

    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        This method is inherited from NodeHandler. It will register the strategies first,
        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
        """
        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
        target = self.get_target_function()
        # Currently we haven't patched all the torch functions and modules, so if the target
        # is not patched, we will use the default cost model to compute the cost.
        # TODO: patch all torch functions and modules to make it clean
        if meta_register.has(target.__class__) or meta_register.has(target):
            metainfo_vector = []
            for strategy in self.strategies_vector:
                metainfo = MetaInfo(strategy, target)
                strategy.compute_cost = metainfo.compute_cost
                strategy.memory_cost = metainfo.memory_cost
                metainfo_vector.append(metainfo)

            # attach metainfos to the handler
            setattr(self, "metainfo_vector", metainfo_vector)

        return self.strategies_vector


class ModuleHandler(NodeHandler):

    def __init__(self, *args, **kwargs) -> None:
@ -168,3 +268,35 @@ class ModuleHandler(NodeHandler):
        self.module = module
        self.named_parameters = named_parameters
        self.named_buffers = named_buffers


class MetaInfoModuleHandler(ModuleHandler):
    """
    This is a base class to handle the module patched in the meta profiler.

    Note: this class will be integrated into the ModuleHandler class in the future, after
    all the modules are patched.
    """

    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        This method is inherited from NodeHandler. It will register the strategies first,
        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
        """
        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
        target = self.get_target_function()
        # Currently we haven't patched all the torch functions and modules, so if the target
        # is not patched, we will use the default cost model to compute the cost.
        # TODO: patch all torch functions and modules to make it clean
        if meta_register.has(target.__class__) or meta_register.has(target):
            metainfo_vector = []
            for strategy in self.strategies_vector:
                metainfo = MetaInfo(strategy, target)
                strategy.compute_cost = metainfo.compute_cost
                strategy.memory_cost = metainfo.memory_cost
                metainfo_vector.append(metainfo)

            # attach metainfos to the handler
            setattr(self, "metainfo_vector", metainfo_vector)

        return self.strategies_vector
@ -3,7 +3,7 @@ from typing import Dict, List
import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator

@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
class NormPoolingHandler(MetaInfoModuleHandler):
    """
    A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
    """
@ -2,38 +2,51 @@ from typing import Dict, List

import torch

from ..sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh

from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator

__all__ = ['OuputHandler']
__all__ = ['OutputHandler']


class OuputHandler(NodeHandler):
class OutputHandler(NodeHandler):
    """
    A OuputHandler which deals with the sharding strategies for Output Node.
    A OutputHandler which deals with the sharding strategies for Output Node.
    """

    def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 output_option: str) -> None:
        super().__init__(node, device_mesh, strategies_vector)
        self.output_option = output_option

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node))
        generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process
        dummy_output = torch.empty(1,).to("meta")
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output)

        mapping = {"output": physical_output}
        mapping = {}
        output_meta_data = []
        for index, input_node in enumerate(self.predecessor_node):
            if not hasattr(input_node, "_meta_data"):
                print(input_node.name)
            physical_inputs = OperationData(name=str(input_node),
                                            type=OperationDataType.ARG,
                                            data=input_node._meta_data)
            input_meta_data = input_node._meta_data
            physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
            name_key = f'input_{index}'
            mapping[name_key] = physical_inputs
            output_meta_data.append(input_meta_data)

        assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
        if len(output_meta_data) == 1:
            output_meta_data = output_meta_data[0]
        else:
            output_meta_data = tuple(output_meta_data)

        self.node._meta_data = output_meta_data
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping["output"] = physical_output
        return mapping
@ -1,21 +1,31 @@
from typing import Dict, List

from ..sharding_strategy import OperationData, OperationDataType
from torch.fx.node import Node

from colossalai.device.device_mesh import DeviceMesh

from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator

__all__ = ['PlacehodlerHandler']
__all__ = ['PlaceholderHandler']


class PlacehodlerHandler(NodeHandler):
class PlaceholderHandler(NodeHandler):
    """
    A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
    A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
    """

    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 placeholder_option: str) -> None:
        super().__init__(node, device_mesh, strategies_vector)
        self.placeholder_option = placeholder_option

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh))
        generators.append(
            PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@ -8,7 +8,12 @@ class Registry:
    def register(self, source):

        def wrapper(func):
            self.store[source] = func
            if isinstance(source, (list, tuple)):
                # support registering a list of items for this func
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func

        return wrapper
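With this change one function can be registered under several sources at once. A minimal stand-alone mirror of the pattern (not the ColossalAI class itself, just the same wrapper logic under assumed names):

class MiniRegistry:
    def __init__(self):
        self.store = {}

    def register(self, source):
        def wrapper(func):
            if isinstance(source, (list, tuple)):
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func
        return wrapper

mini = MiniRegistry()

@mini.register(['relu', 'gelu'])    # one handler registered under two keys
def handle_activation():
    pass

assert mini.store['relu'] is handle_activation and mini.store['gelu'] is handle_activation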
@ -3,18 +3,17 @@ from typing import Dict, List
import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator

__all__ = ['ReshapeHandler']


@operator_registry.register(torch.reshape)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
class ReshapeHandler(MetaInfoNodeHandler):
    """
    A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
    """
@ -25,13 +24,47 @@ class ReshapeHandler(NodeHandler):
        generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def infer_logical_shape(self, data):
        """
        This function is used to infer the logical shape for operands.

        Notes: this function is only needed for operands whose data is not simply a tensor,
        e.g. a tuple of tensors.
        """
        if isinstance(data, torch.Tensor):
            return data.shape
        else:
            assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
            logical_shape = []
            for tensor in data:
                assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
                logical_shape.append(tensor.shape)
            logical_shape = tuple(logical_shape)
            return logical_shape
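For reference, a minimal sketch of what infer_logical_shape returns for its two supported inputs; the tensors are hypothetical meta tensors, mirroring the _meta_data attached by the tracer:

import torch

single = torch.empty(4, 8, device='meta')
pair = (torch.empty(4, 8, device='meta'), torch.empty(2, 3, 5, device='meta'))

# a plain tensor yields its own shape:          torch.Size([4, 8])
# a tuple of tensors yields a tuple of shapes:  (torch.Size([4, 8]), torch.Size([2, 3, 5]))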
    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process

        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        input_logical_shape = self.infer_logical_shape(input_data)
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
                                               type=data_type,
                                               data=input_data,
                                               logical_shape=input_logical_shape)

        output_data = self.node._meta_data
        output_logical_shape = self.infer_logical_shape(output_data)
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_data,
                                        logical_shape=output_logical_shape)

        mapping = {"input": physical_input_operand, "output": physical_output}
@ -0,0 +1,55 @@
from typing import Dict, List

import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator

__all__ = ['SoftmaxHandler']


@operator_registry.register(torch.nn.Softmax)
@operator_registry.register(torch.nn.functional.softmax)
class SoftmaxHandler(NodeHandler):
    """
    A SoftmaxHandler which deals with the sharding strategies for
    torch.nn.Softmax or torch.nn.functional.softmax.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        softmax_dim = self.node.kwargs['dim']

        num_dims = self.node.args[0]._meta_data.dim()
        # recover negative value to positive
        if softmax_dim < 0:
            softmax_dim += num_dims

        physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "softmax_dim": physical_dim_operand,
            "output": physical_output_operand
        }

        return mapping